diff --git a/releso/shape_parameterization.py b/releso/shape_parameterization.py index 2982658..9b14a54 100644 --- a/releso/shape_parameterization.py +++ b/releso/shape_parameterization.py @@ -229,8 +229,10 @@ def apply_continuous_action(self, value: float) -> float: """ delta = self.max_value - self.min_value descaled_value = ((value + 1.0) / 2.0) * delta - self.current_position = np.clip( - descaled_value + self.min_value, self.min_value, self.max_value + self.current_position = float( + np.clip( + descaled_value + self.min_value, self.min_value, self.max_value + ) ) return self.current_position diff --git a/releso/util/caching.py b/releso/util/caching.py index 7e5d963..9e449b8 100644 --- a/releso/util/caching.py +++ b/releso/util/caching.py @@ -6,6 +6,50 @@ import numpy as np +class CachedValue: + """A simple data structure to hold cached values.""" + + def __init__(self, value: list[str], error: int | bool = False): + self.value: list[str] = value + self.error: int | bool = error + + def __len__(self): + return len(self.value) + + def __getitem__(self, index: int): + """Default dunder functionality. + + Args: + index (int): Index to retrieve from the cached value. + + Returns: + Any: The requested item if found, otherwise raises an IndexError. + """ + return self.value[index] + + def get(self, index: int | list[int]): + """Retrieve items from the cached value. + + Args: + index (int | list[int]): Index or list of indices to retrieve from the + cached value. + + Returns: + Any: The requested item if found. + """ + if isinstance(index, list): + return [self.value[i] for i in index] + return self.value[index] + + def __bool__(self): + raise RuntimeError( + "CachedValue cannot be used in a boolean context. Use the 'error' attribute to check for errors." + ) + + def __iter__(self): + return iter(self.value) + + class RelesoSporCache: """A simple SQLite cache for storing SPOR results. @@ -43,6 +87,7 @@ def main(args, logger, func_data): func_data["cache"] = RelesoSporCache( db_path="spor_cache.db", example_data={"data_key1": "value1", "data_key2": "value2"}, + # error_handling=True, # optional, if you want to store error information as well ) if args.reset: # don't touch the cache, only reset the other func_data entries @@ -51,6 +96,9 @@ def main(args, logger, func_data): key = RelesoSporCache.make_cache_key(input_params) cached_value = func_data["cache"].get(key) if cached_value is not None: + if cached_value: # was an error thrown? + # an error is indicated for this result + pass # Use cached value pass else: @@ -62,13 +110,31 @@ def main(args, logger, func_data): key, {"data_key1": result[0], "data_key2": result[1]}, ) + # if error handling is enabled you can add the error information as well + # func_data["cache"].set( + # key, + # {"data_key1": result[0], "data_key2": result[1]}, + # error=error_occurred, + # ) + You can use error handling to also store if an error occurred during the computation. + This can be useful to avoid repeated attempts of expensive computations that are likely + to fail. If error handling is enabled, you can set the "error" key in the value + dictionary to indicate if an error occurred. When retrieving from the cache, you can + check the "error" key to see if the cached value is valid or if it indicates a previous + error. + Params: db_path (str): Path to the SQLite database file. - This file will be created if it does not exist. - The cache will store key-value pairs where keys are strings and values are JSON-serializable - dictionaries. + This file will be created if it does not exist. + The cache will store key-value pairs where keys are strings and values are + JSON-serializable dictionaries. + example_data (dict): An example dictionary that defines the structure of the + values to be stored in the cache. The keys of this dictionary will be used + as the columns in the SQLite table. + error_handling (bool): If True, caching will also store and retrieve info if an + error occurred. Default is False. """ @staticmethod @@ -94,18 +160,29 @@ def make_cache_key( key = json.dumps(normalized_args, sort_keys=True) return key - def __init__(self, db_path: str, example_data: dict): + def __init__( + self, db_path: str, example_data: dict, error_handling: bool = False + ): self.db_path = db_path + self.error_handling = error_handling - self.keys = example_data.keys() + self.keys = list(example_data.keys()) self.value_accessor = ", ".join(f"{key}" for key in self.keys) + self.value_type = ["TEXT" for _ in self.keys] self.value_question_mark = ", ".join("?" for _ in self.keys) + if error_handling: + self.keys.append("error") + self.value_accessor += ", error" + self.value_type.append("INT") + self.value_question_mark += ", ?" self._initialize_db() def _initialize_db(self): """Initialize the SQLite database and create the cache table if it doesn't exist.""" - table_definitions = ", ".join(f"{key} TEXT" for key in self.keys) + table_definitions = ", ".join( + f"{key} {type_}" for key, type_ in zip(self.keys, self.value_type) + ) with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() cursor.execute( @@ -113,27 +190,84 @@ def _initialize_db(self): ) conn.commit() - def set(self, key: str, value: dict): + def _convert_value_to_database_values(self, value: dict) -> list[any]: + """Convert a value dictionary to a list of database values. + + Args: + value (dict): The value dictionary to convert. + + Returns: + list[any]: A list of values corresponding to the database columns. + """ + ret_values = [ + json.dumps(value[key]) for key in self.keys if key != "error" + ] + if self.error_handling: + ret_values.append(value.get("error", value.get("error", 0))) + return ret_values + + def set(self, key: str, value: dict, error: bool | int = False): """Store a value in the cache. + If error handling is enabled, this method will also ensure that the "error" key is present + in the value dictionary. If the value is None and an error is indicated, it will create a + new dictionary with all keys set to -1 except for the "error" key, which will be set to + -1. This allows for storing error information even when the actual value is not available. + Args: key (str): The key to store the value under. value (dict): The value to store, must be a dictionary with keys matching the example data keys. + error (bool | int): Indicates if an error occurred during computation. Default is False. """ - if not isinstance(value, dict) or value.keys() != self.keys: + if self.error_handling: + # if error handling is enabled, we need to ensure that the "error" key is present in the value dictionary + # if the value is None, we will create a new dictionary with the "error" key set to the error value and all + # other keys set to -1. This allows us to store error information even when the actual value is not available. + if value is None: + if bool(error): + value = {} + for k in self.keys: + if k != "error": + value[k] = -1 + else: + raise ValueError( + f"Value must be a dictionary with keys: {self.keys}" + ) + # add error value to the value dictionary if it is not already present + if value.get("error") is None: + value["error"] = int(error) + # check if all values are present and if the keys match the example data keys + if not isinstance(value, dict) or list(value.keys()) != self.keys: raise ValueError( - f"Value must be a dictionary with keys: {self.keys}" + f"Value must be a dictionary with keys: {self.keys}, but has keys: {list(value.keys())}" ) with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() cursor.execute( f"INSERT OR REPLACE INTO spor_cache (key, {self.value_accessor}) VALUES (?, {self.value_question_mark})", - (key, *[json.dumps(value[key]) for key in self.keys]), + (key, *self._convert_value_to_database_values(value)), ) conn.commit() - def get(self, key: str) -> list[dict] | None: + def _cached_value_from_row(self, row) -> CachedValue: + """Convert a database row to a CachedValue instance. + + Args: + row (tuple): A tuple representing a row from the database. + + Returns: + CachedValue: An instance of CachedValue containing the cached data. + """ + print(row) + if self.error_handling: + value = [json.loads(r) for r in row[:-1]] + else: + value = [json.loads(r) for r in row] + error_value = row[-1] if self.error_handling else False + return CachedValue(value=value, error=error_value) + + def get(self, key: str) -> CachedValue | None: """Retrieve a value from the cache. Args: @@ -150,7 +284,7 @@ def get(self, key: str) -> list[dict] | None: f"SELECT {self.value_accessor} FROM spor_cache WHERE key = ?", (key,), ) - rows = cursor.fetchone() - return [json.loads(row) for row in rows] if rows else None + row = cursor.fetchone() + return self._cached_value_from_row(row) if row else None except sqlite3.OperationalError: return None