95 changes: 88 additions & 7 deletions pyhealth/datasets/base_dataset.py
@@ -792,6 +792,22 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
         finally:
             self.clean_tmpdir()

+    def _get_task_cache_dir(self, task: BaseTask) -> Path:
+        """Generate the default cache directory path for a task.
+
+        Args:
+            task (BaseTask): The task for which to generate the cache directory path.
+
+        Returns:
+            Path: The default cache directory path for the task.
+        """
+        task_params = json.dumps(
+            vars(task),
+            sort_keys=True,
+            default=str
+        )
+        return self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
+
     def set_task(
         self,
         task: Optional[BaseTask] = None,
@@ -846,14 +862,8 @@ def set_task(
             f"Setting task {task.task_name} for {self.dataset_name} base dataset..."
         )

-        task_params = json.dumps(
-            vars(task),
-            sort_keys=True,
-            default=str
-        )
-
         if cache_dir is None:
-            cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
+            cache_dir = self._get_task_cache_dir(task)
             cache_dir.mkdir(parents=True, exist_ok=True)
         else:
             # Ensure the explicitly provided cache_dir exists
@@ -905,6 +915,77 @@ def set_task(
             task_name=task.task_name,
         )

+    def clear_cache(self) -> None:
+        """Clears the entire dataset cache, including global event data and all task caches.
+
+        This method removes:
+        - The global event dataframe cache (global_event_df.parquet)
+        - All task-specific caches in the tasks/ directory
+        - Temporary files in the tmp/ directory
+
+        After calling this method, the dataset will need to be reprocessed from
+        scratch the next time it is accessed.
+
+        Note:
+            This operation cannot be undone. Use with caution.
+        """
+        cache_path = self.cache_dir
+
+        if cache_path.exists():
+            logger.info(f"Clearing entire dataset cache at {cache_path}")
+            try:
+                shutil.rmtree(cache_path)
+                logger.info(f"Successfully cleared dataset cache at {cache_path}")
+
+                # Reset cached attributes since the cache has been cleared
+                self._cache_dir = None
+                self._global_event_df = None
+                self._unique_patient_ids = None
+            except Exception as e:
+                logger.error(f"Failed to clear cache at {cache_path}: {e}")
+                raise
+        else:
+            logger.info(f"No cache found at {cache_path}, nothing to clear")
+
+    def clear_task_cache(self, task: Optional[BaseTask] = None) -> None:
+        """Clears the default cache directory for a specific task.
+
+        This method removes only the default task-specific cache directory for the
+        given task, preserving the global event dataframe cache and other task caches.
+
+        Note that if set_task was called with a custom cache_dir parameter, that cache
+        will not be cleared by this method. This only clears the default cache location
+        at {self.cache_dir}/tasks/{task_name}_{uuid5(vars(task))}.
+
+        Args:
+            task (Optional[BaseTask]): The task whose default cache should be cleared.
+                If None, uses the default task.
+
+        Raises:
+            AssertionError: If no default task is found and task is None.
+
+        Note:
+            This operation cannot be undone. The task cache will need to be
+            regenerated the next time set_task is called with this task.
+        """
+        if task is None:
+            assert self.default_task is not None, "No default task found"
+            task = self.default_task
+
+        # Use the same helper method as set_task to ensure consistency
+        task_cache_dir = self._get_task_cache_dir(task)
+
+        if task_cache_dir.exists():
+            logger.info(f"Clearing task cache for '{task.task_name}' at {task_cache_dir}")
+            try:
+                shutil.rmtree(task_cache_dir)
+                logger.info(f"Successfully cleared task cache for '{task.task_name}'")
+            except Exception as e:
+                logger.error(f"Failed to clear task cache at {task_cache_dir}: {e}")
+                raise
+        else:
+            logger.info(f"No cache found for task '{task.task_name}' at {task_cache_dir}, nothing to clear")
+
     def _main_guard(self, func_name: str):
         """Warn if method is accessed from a non-main process."""

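For context, a minimal usage sketch of the cache-management API added above. `MyEHRDataset` and `MyTask` are hypothetical stand-ins for concrete `BaseDataset`/`BaseTask` subclasses, and their constructor arguments are assumed, not taken from this PR; only the `set_task`, `clear_task_cache`, and `clear_cache` calls mirror the diff.

```python
# Hypothetical subclasses standing in for a real BaseDataset/BaseTask pair.
dataset = MyEHRDataset(root="/data/my_ehr")  # assumed constructor signature
task = MyTask(param=1)

samples = dataset.set_task(task)  # builds (or loads) the default task cache

# Remove only this task's cached samples; the global event dataframe
# and other tasks' caches are preserved.
dataset.clear_task_cache(task)

# Remove everything under dataset.cache_dir: global_event_df.parquet,
# the tasks/ directory, and tmp/ files. The next access reprocesses
# the raw data from scratch.
dataset.clear_cache()
```

Note that `clear_task_cache` only targets the default location derived by `_get_task_cache_dir`; a custom `cache_dir` passed to `set_task` must be removed manually.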
103 changes: 103 additions & 0 deletions tests/core/test_caching.py
@@ -225,6 +225,109 @@ def test_tasks_with_diff_param_values_get_diff_caches(self):
         sample_dataset1.close()
         sample_dataset2.close()

+    def test_clear_cache_removes_all_caches(self):
+        """Test that clear_cache removes entire dataset cache."""
+        # Create cache by accessing global_event_df
+        _ = self.dataset.global_event_df
+        cache_path = self.dataset.cache_dir
+
+        # Verify cache exists
+        self.assertTrue(cache_path.exists())
+        self.assertTrue((cache_path / "global_event_df.parquet").exists())
+
+        # Create a task cache
+        sample_dataset = self.dataset.set_task(self.task)
+        task_cache_dir = self.dataset._get_task_cache_dir(self.task)
+        self.assertTrue(task_cache_dir.exists())
+
+        # Store the cache path before clearing
+        cache_path_str = str(cache_path)
+
+        # Clear entire cache
+        self.dataset.clear_cache()
+
+        # Verify cache directory is removed (use stored path to avoid recreation)
+        from pathlib import Path
+        self.assertFalse(Path(cache_path_str).exists())
+
+        # Verify cached attributes are reset
+        self.assertIsNone(self.dataset._cache_dir)
+        self.assertIsNone(self.dataset._global_event_df)
+        self.assertIsNone(self.dataset._unique_patient_ids)
+
+        sample_dataset.close()
+
+    def test_clear_cache_handles_nonexistent_cache(self):
+        """Test that clear_cache handles the case when no cache exists."""
+        # Create a fresh dataset without any cache
+        fresh_dataset = MockDataset(cache_dir=self.cache_dir / "fresh")
+
+        # This should not raise an error
+        fresh_dataset.clear_cache()
+
+    def test_clear_task_cache_removes_only_specified_task(self):
+        """Test that clear_task_cache removes only the specified task cache."""
+        # Create two different tasks
+        task1 = MockTask(param=1)
+        task2 = MockTask(param=2)
+
+        # Set both tasks to create their caches
+        sample_dataset1 = self.dataset.set_task(task1)
+        sample_dataset2 = self.dataset.set_task(task2)
+
+        # Get cache directories and store paths as strings
+        task1_cache_dir = self.dataset._get_task_cache_dir(task1)
+        task2_cache_dir = self.dataset._get_task_cache_dir(task2)
+        global_cache = self.dataset.cache_dir / "global_event_df.parquet"
+
+        task1_cache_str = str(task1_cache_dir)
+        task2_cache_str = str(task2_cache_dir)
+        global_cache_str = str(global_cache)
+
+        # Verify all caches exist
+        self.assertTrue(task1_cache_dir.exists())
+        self.assertTrue(task2_cache_dir.exists())
+        self.assertTrue(global_cache.exists())
+
+        # Clear only task1 cache
+        self.dataset.clear_task_cache(task1)
+
+        # Verify task1 cache is removed but others remain (use stored paths)
+        from pathlib import Path
+        self.assertFalse(Path(task1_cache_str).exists())
+        self.assertTrue(Path(task2_cache_str).exists())
+        self.assertTrue(Path(global_cache_str).exists())
+
+        sample_dataset1.close()
+        sample_dataset2.close()
+
+    def test_clear_task_cache_handles_nonexistent_cache(self):
+        """Test that clear_task_cache handles the case when task cache doesn't exist."""
+        task = MockTask(param=999)
+
+        # This should not raise an error even though cache doesn't exist
+        self.dataset.clear_task_cache(task)
+
+    def test_get_task_cache_dir_consistency(self):
+        """Test that _get_task_cache_dir produces consistent paths."""
+        task = MockTask(param=42)
+
+        # Get path multiple times
+        path1 = self.dataset._get_task_cache_dir(task)
+        path2 = self.dataset._get_task_cache_dir(task)
+
+        # Should be identical
+        self.assertEqual(path1, path2)
+
+        # Should match the pattern used in set_task
+        task_params = json.dumps(
+            vars(task),
+            sort_keys=True,
+            default=str
+        )
+        expected_path = self.dataset.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
+        self.assertEqual(path1, expected_path)
+
+
 if __name__ == "__main__":
     unittest.main()
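As a supplement to `test_get_task_cache_dir_consistency`, here is a self-contained sketch of the naming scheme the helper relies on: `uuid.uuid5` is a deterministic hash, so equal task parameters always map to the same directory name while any parameter change yields a different one. The `task_cache_dir` function below is illustrative, not part of the PR.

```python
import json
import uuid
from pathlib import Path

def task_cache_dir(cache_dir: Path, task_name: str, params: dict) -> Path:
    """Mirror of the uuid5-based naming used by _get_task_cache_dir."""
    task_params = json.dumps(params, sort_keys=True, default=str)
    return cache_dir / "tasks" / f"{task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"

base = Path("/tmp/pyhealth_cache")
a = task_cache_dir(base, "mock_task", {"param": 1})
b = task_cache_dir(base, "mock_task", {"param": 1})
c = task_cache_dir(base, "mock_task", {"param": 2})

assert a == b  # uuid5 is deterministic: same params, same directory (across runs)
assert a != c  # any change to the params yields a distinct directory
```

Because the name depends only on `task_name` and the JSON-serialized `vars(task)` (with `sort_keys=True` for a stable ordering), cache lookups survive process restarts without any registry of previously created directories.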