Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions releso/shape_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ def apply_continuous_action(self, value: float) -> float:
"""
delta = self.max_value - self.min_value
descaled_value = ((value + 1.0) / 2.0) * delta
self.current_position = np.clip(
descaled_value + self.min_value, self.min_value, self.max_value
self.current_position = float(
np.clip(
descaled_value + self.min_value, self.min_value, self.max_value
)
)
return self.current_position

Expand Down
160 changes: 147 additions & 13 deletions releso/util/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,50 @@
import numpy as np


class CachedValue:
"""A simple data structure to hold cached values."""

def __init__(self, value: list[str], error: int | bool = False):
self.value: list[str] = value
self.error: int | bool = error

def __len__(self):
return len(self.value)

def __getitem__(self, index: int):
"""Default dunder functionality.

Args:
index (int): Index to retrieve from the cached value.

Returns:
Any: The requested item if found, otherwise raises an IndexError.
"""
return self.value[index]

def get(self, index: int | list[int]):
"""Retrieve items from the cached value.

Args:
index (int | list[int]): Index or list of indices to retrieve from the
cached value.

Returns:
Any: The requested item if found.
"""
if isinstance(index, list):
return [self.value[i] for i in index]
return self.value[index]

def __bool__(self):
raise RuntimeError(
"CachedValue cannot be used in a boolean context. Use the 'error' attribute to check for errors."
)

def __iter__(self):
return iter(self.value)


class RelesoSporCache:
"""A simple SQLite cache for storing SPOR results.

Expand Down Expand Up @@ -43,6 +87,7 @@ def main(args, logger, func_data):
func_data["cache"] = RelesoSporCache(
db_path="spor_cache.db",
example_data={"data_key1": "value1", "data_key2": "value2"},
# error_handling=True, # optional, if you want to store error information as well
)
if args.reset:
# don't touch the cache, only reset the other func_data entries
Expand All @@ -51,6 +96,9 @@ def main(args, logger, func_data):
key = RelesoSporCache.make_cache_key(input_params)
cached_value = func_data["cache"].get(key)
if cached_value is not None:
if cached_value: # was an error thrown?
# an error is indicated for this result
pass
# Use cached value
pass
else:
Expand All @@ -62,13 +110,31 @@ def main(args, logger, func_data):
key,
{"data_key1": result[0], "data_key2": result[1]},
)
# if error handling is enabled you can add the error information as well
# func_data["cache"].set(
# key,
# {"data_key1": result[0], "data_key2": result[1]},
# error=error_occurred,
# )


You can use error handling to also store if an error occurred during the computation.
This can be useful to avoid repeated attempts of expensive computations that are likely
to fail. If error handling is enabled, you can set the "error" key in the value
dictionary to indicate if an error occurred. When retrieving from the cache, you can
check the "error" key to see if the cached value is valid or if it indicates a previous
error.

Params:
db_path (str): Path to the SQLite database file.
This file will be created if it does not exist.
The cache will store key-value pairs where keys are strings and values are JSON-serializable
dictionaries.
This file will be created if it does not exist.
The cache will store key-value pairs where keys are strings and values are
JSON-serializable dictionaries.
example_data (dict): An example dictionary that defines the structure of the
values to be stored in the cache. The keys of this dictionary will be used
as the columns in the SQLite table.
error_handling (bool): If True, caching will also store and retrieve info if an
error occurred. Default is False.
"""

@staticmethod
Expand All @@ -94,46 +160,114 @@ def make_cache_key(
key = json.dumps(normalized_args, sort_keys=True)
return key

def __init__(self, db_path: str, example_data: dict):
def __init__(
self, db_path: str, example_data: dict, error_handling: bool = False
):
self.db_path = db_path
self.error_handling = error_handling

self.keys = example_data.keys()
self.keys = list(example_data.keys())
self.value_accessor = ", ".join(f"{key}" for key in self.keys)
self.value_type = ["TEXT" for _ in self.keys]
self.value_question_mark = ", ".join("?" for _ in self.keys)
if error_handling:
self.keys.append("error")
self.value_accessor += ", error"
self.value_type.append("INT")
self.value_question_mark += ", ?"

self._initialize_db()

def _initialize_db(self):
"""Initialize the SQLite database and create the cache table if it doesn't exist."""
table_definitions = ", ".join(f"{key} TEXT" for key in self.keys)
table_definitions = ", ".join(
f"{key} {type_}" for key, type_ in zip(self.keys, self.value_type)
)
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
f"CREATE TABLE IF NOT EXISTS spor_cache (key TEXT PRIMARY KEY, {table_definitions})"
)
conn.commit()

def set(self, key: str, value: dict):
def _convert_value_to_database_values(self, value: dict) -> list[any]:
"""Convert a value dictionary to a list of database values.

Args:
value (dict): The value dictionary to convert.

Returns:
list[any]: A list of values corresponding to the database columns.
"""
ret_values = [
json.dumps(value[key]) for key in self.keys if key != "error"
]
if self.error_handling:
ret_values.append(value.get("error", value.get("error", 0)))
return ret_values

def set(self, key: str, value: dict, error: bool | int = False):
"""Store a value in the cache.

If error handling is enabled, this method will also ensure that the "error" key is present
in the value dictionary. If the value is None and an error is indicated, it will create a
new dictionary with all keys set to -1 except for the "error" key, which will be set to
-1. This allows for storing error information even when the actual value is not available.

Args:
key (str): The key to store the value under.
value (dict): The value to store, must be a dictionary with keys matching
the example data keys.
error (bool | int): Indicates if an error occurred during computation. Default is False.
"""
if not isinstance(value, dict) or value.keys() != self.keys:
if self.error_handling:
# if error handling is enabled, we need to ensure that the "error" key is present in the value dictionary
# if the value is None, we will create a new dictionary with the "error" key set to the error value and all
# other keys set to -1. This allows us to store error information even when the actual value is not available.
if value is None:
if bool(error):
value = {}
for k in self.keys:
if k != "error":
value[k] = -1
else:
raise ValueError(
f"Value must be a dictionary with keys: {self.keys}"
)
# add error value to the value dictionary if it is not already present
if value.get("error") is None:
value["error"] = int(error)
# check if all values are present and if the keys match the example data keys
if not isinstance(value, dict) or list(value.keys()) != self.keys:
raise ValueError(
f"Value must be a dictionary with keys: {self.keys}"
f"Value must be a dictionary with keys: {self.keys}, but has keys: {list(value.keys())}"
)
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
f"INSERT OR REPLACE INTO spor_cache (key, {self.value_accessor}) VALUES (?, {self.value_question_mark})",
(key, *[json.dumps(value[key]) for key in self.keys]),
(key, *self._convert_value_to_database_values(value)),
)
conn.commit()

def get(self, key: str) -> list[dict] | None:
def _cached_value_from_row(self, row) -> CachedValue:
"""Convert a database row to a CachedValue instance.

Args:
row (tuple): A tuple representing a row from the database.

Returns:
CachedValue: An instance of CachedValue containing the cached data.
"""
print(row)
if self.error_handling:
value = [json.loads(r) for r in row[:-1]]
else:
value = [json.loads(r) for r in row]
error_value = row[-1] if self.error_handling else False
return CachedValue(value=value, error=error_value)

def get(self, key: str) -> CachedValue | None:
"""Retrieve a value from the cache.

Args:
Expand All @@ -150,7 +284,7 @@ def get(self, key: str) -> list[dict] | None:
f"SELECT {self.value_accessor} FROM spor_cache WHERE key = ?",
(key,),
)
rows = cursor.fetchone()
return [json.loads(row) for row in rows] if rows else None
row = cursor.fetchone()
return self._cached_value_from_row(row) if row else None
except sqlite3.OperationalError:
return None