From a408701a4c8b2a5581e11d3406aac154bf466d39 Mon Sep 17 00:00:00 2001 From: Eshaan Date: Mon, 5 Jan 2026 01:57:32 +1100 Subject: [PATCH 1/3] Add clear_cache and clear_task_cache methods to BaseDataset Implements #765 by adding two new methods to manage cache cleanup: - clear_cache(): Clears entire dataset cache including global event dataframe and all task caches - clear_task_cache(task=None): Clears only the specified task's cache while preserving global event cache and other task caches Both methods handle non-existent caches gracefully and provide comprehensive logging. --- pyhealth/datasets/base_dataset.py | 72 +++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index ec721e8c..ce54cf91 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -905,6 +905,78 @@ def set_task( task_name=task.task_name, ) + def clear_cache(self) -> None: + """Clears the entire dataset cache, including global event data and all task caches. + + This method removes: + - The global event dataframe cache (global_event_df.parquet) + - All task-specific caches in the tasks/ directory + - Temporary files in the tmp/ directory + + After calling this method, the dataset will need to be reprocessed from + scratch the next time it is accessed. + + Note: + This operation cannot be undone. Use with caution. + """ + cache_path = self.cache_dir + + if cache_path.exists(): + logger.info(f"Clearing entire dataset cache at {cache_path}") + try: + shutil.rmtree(cache_path) + logger.info(f"Successfully cleared dataset cache at {cache_path}") + + # Reset cached attributes since the cache has been cleared + self._cache_dir = None + self._global_event_df = None + self._unique_patient_ids = None + except Exception as e: + logger.error(f"Failed to clear cache at {cache_path}: {e}") + raise + else: + logger.info(f"No cache found at {cache_path}, nothing to clear") + + def clear_task_cache(self, task: Optional[BaseTask] = None) -> None: + """Clears the cache for a specific task. + + This method removes only the task-specific cache directory for the given task, + preserving the global event dataframe cache and other task caches. + + Args: + task (Optional[BaseTask]): The task whose cache should be cleared. + If None, uses the default task. + + Raises: + AssertionError: If no default task is found and task is None. + + Note: + This operation cannot be undone. The task cache will need to be + regenerated the next time set_task is called with this task. + """ + if task is None: + assert self.default_task is not None, "No default task found" + task = self.default_task + + # Generate the same task cache directory name as in set_task + task_params = json.dumps( + vars(task), + sort_keys=True, + default=str + ) + task_cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + + if task_cache_dir.exists(): + logger.info(f"Clearing task cache for '{task.task_name}' at {task_cache_dir}") + try: + shutil.rmtree(task_cache_dir) + logger.info(f"Successfully cleared task cache for '{task.task_name}'") + except Exception as e: + logger.error(f"Failed to clear task cache at {task_cache_dir}: {e}") + raise + else: + logger.info(f"No cache found for task '{task.task_name}' at {task_cache_dir}, nothing to clear") + def _main_guard(self, func_name: str): """Warn if method is accessed from a non-main process.""" From f054d37e10776f8fac65e16633653f9f3fcdd25c Mon Sep 17 00:00:00 2001 From: Eshaan Date: Tue, 6 Jan 2026 21:31:30 +1100 Subject: [PATCH 2/3] Address PR feedback: refactor task cache path generation and add tests - Extract task cache directory path generation into _get_task_cache_dir() helper method for consistency between set_task() and clear_task_cache() - Update clear_task_cache() to use the helper method - Clarify in docstring that clear_task_cache() only clears default cache location, not custom cache_dir paths - Add 5 comprehensive tests to tests/core/test_caching.py: * test_clear_cache_removes_all_caches * test_clear_cache_handles_nonexistent_cache * test_clear_task_cache_removes_only_specified_task * test_clear_task_cache_handles_nonexistent_cache * test_get_task_cache_dir_consistency Addresses review feedback from @Logiquo and @EricSchrock on PR #770 --- pyhealth/datasets/base_dataset.py | 43 ++++++++------ tests/core/test_caching.py | 94 +++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 17 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index ce54cf91..09bc0a01 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -792,6 +792,22 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> finally: self.clean_tmpdir() + def _get_task_cache_dir(self, task: BaseTask) -> Path: + """Generate the default cache directory path for a task. + + Args: + task (BaseTask): The task for which to generate the cache directory path. + + Returns: + Path: The default cache directory path for the task. + """ + task_params = json.dumps( + vars(task), + sort_keys=True, + default=str + ) + return self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + def set_task( self, task: Optional[BaseTask] = None, @@ -846,14 +862,8 @@ def set_task( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." ) - task_params = json.dumps( - vars(task), - sort_keys=True, - default=str - ) - if cache_dir is None: - cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + cache_dir = self._get_task_cache_dir(task) cache_dir.mkdir(parents=True, exist_ok=True) else: # Ensure the explicitly provided cache_dir exists @@ -938,13 +948,17 @@ def clear_cache(self) -> None: logger.info(f"No cache found at {cache_path}, nothing to clear") def clear_task_cache(self, task: Optional[BaseTask] = None) -> None: - """Clears the cache for a specific task. + """Clears the default cache directory for a specific task. - This method removes only the task-specific cache directory for the given task, + This method removes only the default task-specific cache directory for the given task, preserving the global event dataframe cache and other task caches. + Note that if set_task was called with a custom cache_dir parameter, that cache + will not be cleared by this method. This only clears the default cache location + at {self.cache_dir}/tasks/{task_name}_{uuid5(vars(task))}. + Args: - task (Optional[BaseTask]): The task whose cache should be cleared. + task (Optional[BaseTask]): The task whose default cache should be cleared. If None, uses the default task. Raises: @@ -958,13 +972,8 @@ def clear_task_cache(self, task: Optional[BaseTask] = None) -> None: assert self.default_task is not None, "No default task found" task = self.default_task - # Generate the same task cache directory name as in set_task - task_params = json.dumps( - vars(task), - sort_keys=True, - default=str - ) - task_cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + # Use the same helper method as set_task to ensure consistency + task_cache_dir = self._get_task_cache_dir(task) if task_cache_dir.exists(): logger.info(f"Clearing task cache for '{task.task_name}' at {task_cache_dir}") diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index fa832c31..e4a30624 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -225,6 +225,100 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): sample_dataset1.close() sample_dataset2.close() + def test_clear_cache_removes_all_caches(self): + """Test that clear_cache removes entire dataset cache.""" + # Create cache by accessing global_event_df + _ = self.dataset.global_event_df + cache_path = self.dataset.cache_dir + + # Verify cache exists + self.assertTrue(cache_path.exists()) + self.assertTrue((cache_path / "global_event_df.parquet").exists()) + + # Create a task cache + sample_dataset = self.dataset.set_task(self.task) + task_cache_dir = self.dataset._get_task_cache_dir(self.task) + self.assertTrue(task_cache_dir.exists()) + + # Clear entire cache + self.dataset.clear_cache() + + # Verify everything is removed + self.assertFalse(cache_path.exists()) + + # Verify cached attributes are reset + self.assertIsNone(self.dataset._cache_dir) + self.assertIsNone(self.dataset._global_event_df) + self.assertIsNone(self.dataset._unique_patient_ids) + + sample_dataset.close() + + def test_clear_cache_handles_nonexistent_cache(self): + """Test that clear_cache handles the case when no cache exists.""" + # Create a fresh dataset without any cache + fresh_dataset = MockDataset(cache_dir=self.cache_dir / "fresh") + + # This should not raise an error + fresh_dataset.clear_cache() + + def test_clear_task_cache_removes_only_specified_task(self): + """Test that clear_task_cache removes only the specified task cache.""" + # Create two different tasks + task1 = MockTask(param=1) + task2 = MockTask(param=2) + + # Set both tasks to create their caches + sample_dataset1 = self.dataset.set_task(task1) + sample_dataset2 = self.dataset.set_task(task2) + + # Get cache directories + task1_cache_dir = self.dataset._get_task_cache_dir(task1) + task2_cache_dir = self.dataset._get_task_cache_dir(task2) + global_cache = self.dataset.cache_dir / "global_event_df.parquet" + + # Verify all caches exist + self.assertTrue(task1_cache_dir.exists()) + self.assertTrue(task2_cache_dir.exists()) + self.assertTrue(global_cache.exists()) + + # Clear only task1 cache + self.dataset.clear_task_cache(task1) + + # Verify task1 cache is removed but others remain + self.assertFalse(task1_cache_dir.exists()) + self.assertTrue(task2_cache_dir.exists()) + self.assertTrue(global_cache.exists()) + + sample_dataset1.close() + sample_dataset2.close() + + def test_clear_task_cache_handles_nonexistent_cache(self): + """Test that clear_task_cache handles the case when task cache doesn't exist.""" + task = MockTask(param=999) + + # This should not raise an error even though cache doesn't exist + self.dataset.clear_task_cache(task) + + def test_get_task_cache_dir_consistency(self): + """Test that _get_task_cache_dir produces consistent paths.""" + task = MockTask(param=42) + + # Get path multiple times + path1 = self.dataset._get_task_cache_dir(task) + path2 = self.dataset._get_task_cache_dir(task) + + # Should be identical + self.assertEqual(path1, path2) + + # Should match the pattern used in set_task + task_params = json.dumps( + vars(task), + sort_keys=True, + default=str + ) + expected_path = self.dataset.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + self.assertEqual(path1, expected_path) + if __name__ == "__main__": unittest.main() From 646351ddad5afa7fc0401dfddafdc4e87a561f5d Mon Sep 17 00:00:00 2001 From: Eshaan Date: Wed, 7 Jan 2026 03:01:30 +1100 Subject: [PATCH 3/3] Fix cache clearing tests to avoid path recreation Store cache paths as strings before calling clear methods to prevent the cache_dir property from recreating directories when accessed after clearing. This ensures the tests correctly verify that caches are removed. Fixes CI test failures in test_clear_cache_removes_all_caches and test_clear_task_cache_removes_only_specified_task. --- tests/core/test_caching.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index e4a30624..18d9a03a 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -240,11 +240,15 @@ def test_clear_cache_removes_all_caches(self): task_cache_dir = self.dataset._get_task_cache_dir(self.task) self.assertTrue(task_cache_dir.exists()) + # Store the cache path before clearing + cache_path_str = str(cache_path) + # Clear entire cache self.dataset.clear_cache() - # Verify everything is removed - self.assertFalse(cache_path.exists()) + # Verify cache directory is removed (use stored path to avoid recreation) + from pathlib import Path + self.assertFalse(Path(cache_path_str).exists()) # Verify cached attributes are reset self.assertIsNone(self.dataset._cache_dir) @@ -271,11 +275,15 @@ def test_clear_task_cache_removes_only_specified_task(self): sample_dataset1 = self.dataset.set_task(task1) sample_dataset2 = self.dataset.set_task(task2) - # Get cache directories + # Get cache directories and store paths as strings task1_cache_dir = self.dataset._get_task_cache_dir(task1) task2_cache_dir = self.dataset._get_task_cache_dir(task2) global_cache = self.dataset.cache_dir / "global_event_df.parquet" + task1_cache_str = str(task1_cache_dir) + task2_cache_str = str(task2_cache_dir) + global_cache_str = str(global_cache) + # Verify all caches exist self.assertTrue(task1_cache_dir.exists()) self.assertTrue(task2_cache_dir.exists()) @@ -284,10 +292,11 @@ def test_clear_task_cache_removes_only_specified_task(self): # Clear only task1 cache self.dataset.clear_task_cache(task1) - # Verify task1 cache is removed but others remain - self.assertFalse(task1_cache_dir.exists()) - self.assertTrue(task2_cache_dir.exists()) - self.assertTrue(global_cache.exists()) + # Verify task1 cache is removed but others remain (use stored paths) + from pathlib import Path + self.assertFalse(Path(task1_cache_str).exists()) + self.assertTrue(Path(task2_cache_str).exists()) + self.assertTrue(Path(global_cache_str).exists()) sample_dataset1.close() sample_dataset2.close()