Skip to content

Commit a408701

Browse files
committed
Add clear_cache and clear_task_cache methods to BaseDataset
Implements #765 by adding two new methods to manage cache cleanup:
- clear_cache(): Clears the entire dataset cache, including the global event dataframe and all task caches.
- clear_task_cache(task=None): Clears only the specified task's cache while preserving the global event cache and other task caches.

Both methods handle non-existent caches gracefully and provide comprehensive logging.
1 parent 82011b0 commit a408701

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

pyhealth/datasets/base_dataset.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,78 @@ def set_task(
905905
task_name=task.task_name,
906906
)
907907

908+
def clear_cache(self) -> None:
    """Delete this dataset's entire on-disk cache directory.

    The directory returned by ``self.cache_dir`` is removed recursively.
    That directory is expected to hold:
        - the global event dataframe cache (global_event_df.parquet)
        - every task-specific cache under tasks/
        - temporary files under tmp/

    After a successful removal the in-memory attributes ``_cache_dir``,
    ``_global_event_df`` and ``_unique_patient_ids`` are reset to ``None``,
    so the dataset has to be reprocessed from scratch on next access.
    When the directory does not exist the call only logs a message and
    leaves those attributes untouched.

    Raises:
        Exception: Any error raised while deleting the directory is
            logged and then re-raised unchanged.

    Note:
        This operation cannot be undone. Use with caution.
    """
    cache_root = self.cache_dir

    # Guard clause: with nothing on disk there is nothing to delete or reset.
    if not cache_root.exists():
        logger.info(f"No cache found at {cache_root}, nothing to clear")
        return

    logger.info(f"Clearing entire dataset cache at {cache_root}")
    try:
        shutil.rmtree(cache_root)
        logger.info(f"Successfully cleared dataset cache at {cache_root}")

        # The files backing these attributes are gone, so drop the stale
        # in-memory state as well.
        # NOTE(review): _cache_dir is reset too; if it can hold a
        # caller-supplied location, that choice is forgotten here -- confirm
        # against the cache_dir accessor.
        self._cache_dir = None
        self._global_event_df = None
        self._unique_patient_ids = None
    except Exception as e:
        logger.error(f"Failed to clear cache at {cache_root}: {e}")
        raise
939+
940+
def clear_task_cache(self, task: Optional[BaseTask] = None) -> None:
    """Clear the on-disk cache of a single task.

    Only the directory ``<cache_dir>/tasks/<task_name>_<uuid5>`` is removed,
    where the UUID is derived from the task's attributes. The global event
    dataframe cache and the caches of other tasks are not touched. When the
    directory does not exist the call only logs a message.

    Args:
        task (Optional[BaseTask]): The task whose cache should be cleared.
            If None, ``self.default_task`` is used.

    Raises:
        AssertionError: If ``task`` is None and no default task is found.
        Exception: Any error raised while deleting the directory is
            logged and then re-raised unchanged.

    Note:
        This operation cannot be undone. The task cache will need to be
        regenerated the next time set_task is called with this task.
    """
    if task is None:
        task = self.default_task
        # Raised explicitly instead of via an ``assert`` statement: asserts
        # are stripped under ``python -O``, which would let ``None`` reach
        # ``vars()`` below and fail with an unrelated TypeError.
        if task is None:
            raise AssertionError("No default task found")

    # Intended to reproduce the directory name used by set_task: the task's
    # attributes are serialised deterministically (sorted keys, non-JSON
    # values stringified) and hashed into a stable UUID5.
    task_params = json.dumps(vars(task), sort_keys=True, default=str)
    task_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, task_params)
    task_cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{task_uuid}"

    # Guard clause: a missing directory is not an error.
    if not task_cache_dir.exists():
        logger.info(f"No cache found for task '{task.task_name}' at {task_cache_dir}, nothing to clear")
        return

    logger.info(f"Clearing task cache for '{task.task_name}' at {task_cache_dir}")
    try:
        shutil.rmtree(task_cache_dir)
        logger.info(f"Successfully cleared task cache for '{task.task_name}'")
    except Exception as e:
        logger.error(f"Failed to clear task cache at {task_cache_dir}: {e}")
        raise
979+
908980
def _main_guard(self, func_name: str):
909981
"""Warn if method is accessed from a non-main process."""
910982

0 commit comments

Comments
 (0)