diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index ec721e8c..09bc0a01 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -792,6 +792,22 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
         finally:
             self.clean_tmpdir()
 
+    def _get_task_cache_dir(self, task: BaseTask) -> Path:
+        """Generate the default cache directory path for a task.
+
+        Args:
+            task (BaseTask): The task for which to generate the cache directory path.
+
+        Returns:
+            Path: The default cache directory path for the task.
+        """
+        task_params = json.dumps(
+            vars(task),
+            sort_keys=True,
+            default=str
+        )
+        return self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
+
     def set_task(
         self,
         task: Optional[BaseTask] = None,
@@ -846,14 +862,8 @@ def set_task(
             f"Setting task {task.task_name} for {self.dataset_name} base dataset..."
         )
 
-        task_params = json.dumps(
-            vars(task),
-            sort_keys=True,
-            default=str
-        )
-
         if cache_dir is None:
-            cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
+            cache_dir = self._get_task_cache_dir(task)
             cache_dir.mkdir(parents=True, exist_ok=True)
         else:
             # Ensure the explicitly provided cache_dir exists
@@ -905,6 +915,77 @@ def set_task(
             task_name=task.task_name,
         )
 
+    def clear_cache(self) -> None:
+        """Clears the entire dataset cache, including global event data and all task caches.
+
+        This method removes:
+        - The global event dataframe cache (global_event_df.parquet)
+        - All task-specific caches in the tasks/ directory
+        - Temporary files in the tmp/ directory
+
+        After calling this method, the dataset will need to be reprocessed from
+        scratch the next time it is accessed.
+
+        Note:
+            This operation cannot be undone. Use with caution.
+        """
+        cache_path = self.cache_dir
+
+        if cache_path.exists():
+            logger.info(f"Clearing entire dataset cache at {cache_path}")
+            try:
+                shutil.rmtree(cache_path)
+                logger.info(f"Successfully cleared dataset cache at {cache_path}")
+
+                # Reset cached attributes since the cache has been cleared
+                self._cache_dir = None
+                self._global_event_df = None
+                self._unique_patient_ids = None
+            except Exception as e:
+                logger.error(f"Failed to clear cache at {cache_path}: {e}")
+                raise
+        else:
+            logger.info(f"No cache found at {cache_path}, nothing to clear")
+
+    def clear_task_cache(self, task: Optional[BaseTask] = None) -> None:
+        """Clears the default cache directory for a specific task.
+
+        This method removes only the default task-specific cache directory for the given task,
+        preserving the global event dataframe cache and other task caches.
+
+        Note that if set_task was called with a custom cache_dir parameter, that cache
+        will not be cleared by this method. This only clears the default cache location
+        at {self.cache_dir}/tasks/{task_name}_{uuid5(vars(task))}.
+
+        Args:
+            task (Optional[BaseTask]): The task whose default cache should be cleared.
+                If None, uses the default task.
+
+        Raises:
+            AssertionError: If no default task is found and task is None.
+
+        Note:
+            This operation cannot be undone. The task cache will need to be
+            regenerated the next time set_task is called with this task.
+        """
+        if task is None:
+            assert self.default_task is not None, "No default task found"
+            task = self.default_task
+
+        # Use the same helper method as set_task to ensure consistency
+        task_cache_dir = self._get_task_cache_dir(task)
+
+        if task_cache_dir.exists():
+            logger.info(f"Clearing task cache for '{task.task_name}' at {task_cache_dir}")
+            try:
+                shutil.rmtree(task_cache_dir)
+                logger.info(f"Successfully cleared task cache for '{task.task_name}'")
+            except Exception as e:
+                logger.error(f"Failed to clear task cache at {task_cache_dir}: {e}")
+                raise
+        else:
+            logger.info(f"No cache found for task '{task.task_name}' at {task_cache_dir}, nothing to clear")
+
     def _main_guard(self, func_name: str):
         """Warn if method is accessed from a non-main process."""
diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py
index fa832c31..18d9a03a 100644
--- a/tests/core/test_caching.py
+++ b/tests/core/test_caching.py
@@ -225,6 +225,109 @@ def test_tasks_with_diff_param_values_get_diff_caches(self):
         sample_dataset1.close()
         sample_dataset2.close()
 
+    def test_clear_cache_removes_all_caches(self):
+        """Test that clear_cache removes entire dataset cache."""
+        # Create cache by accessing global_event_df
+        _ = self.dataset.global_event_df
+        cache_path = self.dataset.cache_dir
+
+        # Verify cache exists
+        self.assertTrue(cache_path.exists())
+        self.assertTrue((cache_path / "global_event_df.parquet").exists())
+
+        # Create a task cache
+        sample_dataset = self.dataset.set_task(self.task)
+        task_cache_dir = self.dataset._get_task_cache_dir(self.task)
+        self.assertTrue(task_cache_dir.exists())
+
+        # Store the cache path before clearing
+        cache_path_str = str(cache_path)
+
+        # Clear entire cache
+        self.dataset.clear_cache()
+
+        # Verify cache directory is removed (use stored path to avoid recreation)
+        from pathlib import Path
+        self.assertFalse(Path(cache_path_str).exists())
+
+        # Verify cached attributes are reset
+        self.assertIsNone(self.dataset._cache_dir)
+        self.assertIsNone(self.dataset._global_event_df)
+        self.assertIsNone(self.dataset._unique_patient_ids)
+
+        sample_dataset.close()
+
+    def test_clear_cache_handles_nonexistent_cache(self):
+        """Test that clear_cache handles the case when no cache exists."""
+        # Create a fresh dataset without any cache
+        fresh_dataset = MockDataset(cache_dir=self.cache_dir / "fresh")
+
+        # This should not raise an error
+        fresh_dataset.clear_cache()
+
+    def test_clear_task_cache_removes_only_specified_task(self):
+        """Test that clear_task_cache removes only the specified task cache."""
+        # Create two different tasks
+        task1 = MockTask(param=1)
+        task2 = MockTask(param=2)
+
+        # Set both tasks to create their caches
+        sample_dataset1 = self.dataset.set_task(task1)
+        sample_dataset2 = self.dataset.set_task(task2)
+
+        # Get cache directories and store paths as strings
+        task1_cache_dir = self.dataset._get_task_cache_dir(task1)
+        task2_cache_dir = self.dataset._get_task_cache_dir(task2)
+        global_cache = self.dataset.cache_dir / "global_event_df.parquet"
+
+        task1_cache_str = str(task1_cache_dir)
+        task2_cache_str = str(task2_cache_dir)
+        global_cache_str = str(global_cache)
+
+        # Verify all caches exist
+        self.assertTrue(task1_cache_dir.exists())
+        self.assertTrue(task2_cache_dir.exists())
+        self.assertTrue(global_cache.exists())
+
+        # Clear only task1 cache
+        self.dataset.clear_task_cache(task1)
+
+        # Verify task1 cache is removed but others remain (use stored paths)
+        from pathlib import Path
+        self.assertFalse(Path(task1_cache_str).exists())
+        self.assertTrue(Path(task2_cache_str).exists())
+        self.assertTrue(Path(global_cache_str).exists())
+
+        sample_dataset1.close()
+        sample_dataset2.close()
+
+    def test_clear_task_cache_handles_nonexistent_cache(self):
+        """Test that clear_task_cache handles the case when task cache doesn't exist."""
+        task = MockTask(param=999)
+
+        # This should not raise an error even though cache doesn't exist
+        self.dataset.clear_task_cache(task)
+
+    def test_get_task_cache_dir_consistency(self):
+        """Test that _get_task_cache_dir produces consistent paths."""
+        task = MockTask(param=42)
+
+        # Get path multiple times
+        path1 = self.dataset._get_task_cache_dir(task)
+        path2 = self.dataset._get_task_cache_dir(task)
+
+        # Should be identical
+        self.assertEqual(path1, path2)
+
+        # Should match the pattern used in set_task
+        task_params = json.dumps(
+            vars(task),
+            sort_keys=True,
+            default=str
+        )
+        expected_path = self.dataset.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
+        self.assertEqual(path1, expected_path)
+
 if __name__ == "__main__":
     unittest.main()