@@ -905,6 +905,78 @@ def set_task(
905905 task_name = task .task_name ,
906906 )
907907
def clear_cache(self) -> None:
    """Delete the dataset's whole cache directory and reset in-memory caches.

    The directory returned by ``self.cache_dir`` is removed recursively, so
    everything stored beneath it is lost. That directory is expected to hold
    the global event dataframe cache (``global_event_df.parquet``), the
    per-task caches under ``tasks/`` and temporary files under ``tmp/``
    (layout not verified from this method alone).

    On successful removal, ``_cache_dir``, ``_global_event_df`` and
    ``_unique_patient_ids`` are set back to ``None``; the dataset will
    presumably be reprocessed from scratch on its next access.

    If the directory does not exist, nothing is deleted and the in-memory
    attributes are left untouched.

    Raises:
        Exception: Any error raised while removing the directory is logged
            and then re-raised unchanged.

    Note:
        This operation cannot be undone. Use with caution.
    """
    target = self.cache_dir

    # Guard clause: nothing on disk means nothing to do.
    if not target.exists():
        logger.info(f"No cache found at {target}, nothing to clear")
        return

    logger.info(f"Clearing entire dataset cache at {target}")
    try:
        shutil.rmtree(target)
        logger.info(f"Successfully cleared dataset cache at {target}")

        # The on-disk cache is gone, so drop the in-memory references that
        # were derived from it. This only happens when removal succeeded.
        self._cache_dir = None
        self._global_event_df = None
        self._unique_patient_ids = None
    except Exception as err:
        # Log for visibility, then let the caller see the original error.
        logger.error(f"Failed to clear cache at {target}: {err}")
        raise
939+
def clear_task_cache(self, task: Optional[BaseTask] = None) -> None:
    """Delete the on-disk cache directory of a single task.

    Only ``<cache_dir>/tasks/<task_name>_<uuid5>`` is removed; nothing else
    under ``self.cache_dir`` is touched by this method.

    Args:
        task (Optional[BaseTask]): The task whose cache should be cleared.
            If None, ``self.default_task`` is used.

    Raises:
        AssertionError: If ``task`` is None and ``self.default_task`` is
            also None.
        Exception: Any error raised while removing the directory is logged
            and then re-raised unchanged.

    Note:
        This operation cannot be undone. The task cache will presumably be
        regenerated the next time ``set_task`` is called with this task.
    """
    if task is None:
        # Raise explicitly rather than using `assert`: assertions are
        # stripped under `python -O`, which would let None reach `vars()`
        # below and fail with an unrelated TypeError. The exception type and
        # message are kept so existing callers are unaffected.
        if self.default_task is None:
            raise AssertionError("No default task found")
        task = self.default_task

    # Rebuild the directory name from the task's attributes. This is meant
    # to match the naming used by set_task (not visible here) -- keep the
    # two in sync. `default=str` stringifies any attribute that is not
    # JSON-serializable.
    task_params = json.dumps(vars(task), sort_keys=True, default=str)
    task_id = uuid.uuid5(uuid.NAMESPACE_DNS, task_params)
    task_cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{task_id}"

    # Guard clause: no directory for this task means nothing to do.
    if not task_cache_dir.exists():
        logger.info(
            f"No cache found for task '{task.task_name}' at {task_cache_dir}, nothing to clear"
        )
        return

    logger.info(f"Clearing task cache for '{task.task_name}' at {task_cache_dir}")
    try:
        shutil.rmtree(task_cache_dir)
        logger.info(f"Successfully cleared task cache for '{task.task_name}'")
    except Exception as e:
        # Log for visibility, then let the caller see the original error.
        logger.error(f"Failed to clear task cache at {task_cache_dir}: {e}")
        raise
979+
908980 def _main_guard (self , func_name : str ):
909981 """Warn if method is accessed from a non-main process."""
910982
0 commit comments