Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,78 @@ def set_task(
task_name=task.task_name,
)

def clear_cache(self) -> None:
"""Clears the entire dataset cache, including global event data and all task caches.

This method removes:
- The global event dataframe cache (global_event_df.parquet)
- All task-specific caches in the tasks/ directory
- Temporary files in the tmp/ directory

After calling this method, the dataset will need to be reprocessed from
scratch the next time it is accessed.

Note:
This operation cannot be undone. Use with caution.
"""
cache_path = self.cache_dir

if cache_path.exists():
logger.info(f"Clearing entire dataset cache at {cache_path}")
try:
shutil.rmtree(cache_path)
logger.info(f"Successfully cleared dataset cache at {cache_path}")

# Reset cached attributes since the cache has been cleared
self._cache_dir = None
self._global_event_df = None
self._unique_patient_ids = None
except Exception as e:
logger.error(f"Failed to clear cache at {cache_path}: {e}")
raise
else:
logger.info(f"No cache found at {cache_path}, nothing to clear")

def clear_task_cache(self, task: Optional[BaseTask] = None) -> None:
"""Clears the cache for a specific task.

This method removes only the task-specific cache directory for the given task,
preserving the global event dataframe cache and other task caches.

Args:
task (Optional[BaseTask]): The task whose cache should be cleared.
If None, uses the default task.

Raises:
AssertionError: If no default task is found and task is None.

Note:
This operation cannot be undone. The task cache will need to be
regenerated the next time set_task is called with this task.
"""
if task is None:
assert self.default_task is not None, "No default task found"
task = self.default_task

# Generate the same task cache directory name as in set_task
task_params = json.dumps(
vars(task),
sort_keys=True,
default=str
)
task_cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"

if task_cache_dir.exists():
logger.info(f"Clearing task cache for '{task.task_name}' at {task_cache_dir}")
try:
shutil.rmtree(task_cache_dir)
logger.info(f"Successfully cleared task cache for '{task.task_name}'")
except Exception as e:
logger.error(f"Failed to clear task cache at {task_cache_dir}: {e}")
raise
else:
logger.info(f"No cache found for task '{task.task_name}' at {task_cache_dir}, nothing to clear")

def _main_guard(self, func_name: str):
"""Warn if method is accessed from a non-main process."""

Expand Down