@@ -118,7 +118,7 @@ def _csv_tsv_gz_path(path: str) -> str:
118118class _ProgressContext :
119119 def __init__ (self , queue : multiprocessing .queues .Queue | None , total : int , ** kwargs ):
120120 """
121- :param queue: An existing queue (e.g., from multiprocessing). If provided,
121+ :param queue: An existing queue (e.g., from multiprocessing). If provided,
122122 this class acts as a passthrough.
123123 :param total: Total items for the progress bar (only used if queue is None).
124124 :param kwargs: Extra arguments for tqdm (e.g., desc="Processing").
@@ -135,7 +135,7 @@ def put(self, n):
135135 def __enter__ (self ):
136136 if self .queue :
137137 return self .queue
138-
138+
139139 self .progress = tqdm (total = self .total , ** self .kwargs )
140140 return self
141141
@@ -158,7 +158,7 @@ def _task_transform_init(queue: multiprocessing.queues.Queue) -> None:
158158def _task_transform_fn (args : tuple [int , BaseTask , Iterable [str ], pl .LazyFrame , Path ]) -> None :
159159 """
160160 Worker function to apply task transformation on a chunk of patients.
161-
161+
162162 Args:
163163 args (tuple): A tuple containing:
164164 worker_id (int): The ID of the worker.
@@ -171,13 +171,13 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P
171171 worker_id , task , patient_ids , global_event_df , output_dir = args
172172 total_patients = len (list (patient_ids ))
173173 logger .info (f"Worker { worker_id } started processing { total_patients } patients. (Polars threads: { pl .thread_pool_size ()} )" )
174-
174+
175175 with (
176- set_env (DATA_OPTIMIZER_GLOBAL_RANK = str (worker_id )),
176+ set_env (DATA_OPTIMIZER_GLOBAL_RANK = str (worker_id )),
177177 _ProgressContext (_task_transform_progress , total = total_patients ) as progress
178178 ):
179179 writer = BinaryWriter (cache_dir = str (output_dir ), chunk_bytes = "64MB" )
180-
180+
181181 write_index = 0
182182 batches = itertools .batched (patient_ids , BATCH_SIZE )
183183 for batch in batches :
@@ -210,11 +210,11 @@ def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None:
210210 """
211211 global _proc_transform_progress
212212 _proc_transform_progress = queue
213-
213+
214214def _proc_transform_fn (args : tuple [int , Path , int , int , Path ]) -> None :
215215 """
216216 Worker function to apply processors on a chunk of samples.
217-
217+
218218 Args:
219219 args (tuple): A tuple containing:
220220 worker_id (int): The ID of the worker.
@@ -233,15 +233,15 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
233233 _ProgressContext (_proc_transform_progress , total = total_samples ) as progress
234234 ):
235235 writer = BinaryWriter (cache_dir = str (output_dir ), chunk_bytes = "64MB" )
236-
236+
237237 dataset = litdata .StreamingDataset (str (task_df ))
238238 complete = 0
239239 with open (f"{ output_dir } /schema.pkl" , "rb" ) as f :
240240 metadata = pickle .load (f )
241241
242242 input_processors = metadata ["input_processors" ]
243243 output_processors = metadata ["output_processors" ]
244-
244+
245245 write_index = 0
246246 for i in range (start_idx , end_idx ):
247247 transformed : Dict [str , Any ] = {}
@@ -255,7 +255,7 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
255255 writer .add_item (write_index , transformed )
256256 write_index += 1
257257 complete += 1
258-
258+
259259 if complete >= BATCH_SIZE :
260260 progress .put (complete )
261261 complete = 0
@@ -680,7 +680,7 @@ def default_task(self) -> Optional[BaseTask]:
680680
681681 def _task_transform (self , task : BaseTask , output_dir : Path , num_workers : int ) -> None :
682682 self ._main_guard (self ._task_transform .__name__ )
683-
683+
684684 logger .info (f"Applying task transformations on data with { num_workers } workers..." )
685685 global_event_df = task .pre_filter (self .global_event_df )
686686 patient_ids = (
@@ -691,16 +691,16 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
691691 # .sort can reduce runtime by 5%.
692692 .sort ()
693693 )
694-
694+
695695 if in_notebook ():
696696 logger .info ("Detected Jupyter notebook environment, setting num_workers to 1" )
697697 num_workers = 1
698698 num_workers = min (num_workers , len (patient_ids )) # Avoid spawning empty workers
699-
699+
700700 # Limit each worker's Polars thread pool to avoid CPU oversubscription;
701701 # doing so can yield an additional 75% speedup when num_workers is large.
702702 threads_per_worker = max (1 , (os .cpu_count () or 1 ) // num_workers )
703-
703+
704704 try :
705705 with set_env (POLARS_MAX_THREADS = str (threads_per_worker ), DATA_OPTIMIZER_NUM_WORKERS = str (num_workers )):
706706 if num_workers == 1 :
@@ -727,7 +727,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
727727 progress .update (queue .get (timeout = 1 ))
728728 except :
729729 pass
730-
730+
731731 # remaining items
732732 while not queue .empty ():
733733 progress .update (queue .get ())
@@ -739,17 +739,17 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
739739 logger .error (f"Error during task transformation, cleaning up output directory: { output_dir } " )
740740 shutil .rmtree (output_dir )
741741 raise e
742-
742+
743743 def _proc_transform (self , task_df : Path , output_dir : Path , num_workers : int ) -> None :
744744 self ._main_guard (self ._proc_transform .__name__ )
745-
745+
746746 logger .info (f"Applying processors on data with { num_workers } workers..." )
747747 num_samples = len (litdata .StreamingDataset (str (task_df )))
748-
748+
749749 if in_notebook ():
750750 logger .info ("Detected Jupyter notebook environment, setting num_workers to 1" )
751751 num_workers = 1
752-
752+
753753 num_workers = min (num_workers , num_samples ) # Avoid spawning empty workers
754754 try :
755755 with set_env (DATA_OPTIMIZER_NUM_WORKERS = str (num_workers )):
@@ -758,7 +758,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
758758 _proc_transform_fn ((0 , task_df , 0 , num_samples , output_dir ))
759759 BinaryWriter (cache_dir = str (output_dir ), chunk_bytes = "64MB" ).merge (num_workers )
760760 return
761-
761+
762762 ctx = multiprocessing .get_context ("spawn" )
763763 queue = ctx .Queue ()
764764 linspace = more_itertools .sliding_window (np .linspace (0 , num_samples , num_workers + 1 , dtype = int ), 2 )
@@ -777,7 +777,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
777777 progress .update (queue .get (timeout = 1 ))
778778 except :
779779 pass
780-
780+
781781 # remaining items
782782 while not queue .empty ():
783783 progress .update (queue .get ())
@@ -814,8 +814,8 @@ def set_task(
814814 Args:
815815 task (Optional[BaseTask]): The task to set. Uses default task if None.
816816 num_workers (int): Number of workers for multiprocessing. Default is `self.num_workers`.
817- cache_dir (Optional[str]): Directory to cache samples after task transformation,
818- but without applying processors. Default is {self.cache_dir}/tasks/{task_name}.
817+ cache_dir (Optional[str]): Directory to cache samples after task transformation,
818+ but without applying processors. Default is {self.cache_dir}/tasks/{task_name}_{uuid5}, where uuid5 is derived from the JSON-serialized task attributes (vars(task)).
819819 cache_format (str): Deprecated. Only "parquet" is supported now.
820820 input_processors (Optional[Dict[str, FeatureProcessor]]):
821821 Pre-fitted input processors. If provided, these will be used
@@ -835,7 +835,7 @@ def set_task(
835835 if task is None :
836836 assert self .default_task is not None , "No default tasks found"
837837 task = self .default_task
838-
838+
839839 if num_workers is None :
840840 num_workers = self .num_workers
841841
@@ -846,8 +846,14 @@ def set_task(
846846 f"Setting task { task .task_name } for { self .dataset_name } base dataset..."
847847 )
848848
849+ task_params = json .dumps (
850+ vars (task ),
851+ sort_keys = True ,
852+ default = str
853+ )
854+
849855 if cache_dir is None :
850- cache_dir = self .cache_dir / "tasks" / task .task_name
856+ cache_dir = self .cache_dir / "tasks" / f" { task .task_name } _ { uuid . uuid5 ( uuid . NAMESPACE_DNS , task_params ) } "
851857 cache_dir .mkdir (parents = True , exist_ok = True )
852858 else :
853859 # Ensure the explicitly provided cache_dir exists
@@ -856,7 +862,7 @@ def set_task(
856862
857863 task_df_path = Path (cache_dir ) / "task_df.ld"
858864 samples_path = Path (cache_dir ) / f"samples_{ uuid .uuid4 ()} .ld"
859-
865+
860866 task_df_path .mkdir (parents = True , exist_ok = True )
861867 samples_path .mkdir (parents = True , exist_ok = True )
862868
0 commit comments