@@ -115,6 +115,30 @@ def _csv_tsv_gz_path(path: str) -> str:
 
     raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")
 
+
+def _litdata_merge(cache_dir: Path) -> None:
+    """
+    Merges LitData binary writer index files in the given cache directory.
+
+    Args:
+        cache_dir (Path): The cache directory containing LitData binary writer files.
+    """
+    from litdata.streaming.writer import _INDEX_FILENAME
+    files = os.listdir(cache_dir)
+
+    # Return if the index already exists
+    if _INDEX_FILENAME in files:
+        return
+
+    index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]
+
+    # Raise if there are no index files to merge, i.e. no worker wrote any samples
+    if len(index_files) == 0:
+        raise ValueError("There are zero samples in the dataset, please check the task and processors.")
+
+    BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(num_workers=len(index_files))
+
+
 class _ProgressContext:
     def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs):
         """
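As a quick orientation on the helper above (not part of the diff): a minimal usage sketch, assuming `_litdata_merge` is in scope and using a placeholder cache directory rather than a path from this repository.

```python
from pathlib import Path

# Placeholder directory where the per-worker BinaryWriter instances wrote their
# chunks and per-worker index files (not a path taken from this diff).
output_dir = Path("cache/task_output")

# Per the function above: skip if the combined index already exists, raise
# ValueError if no worker produced an index file, otherwise merge with one
# writer slot per per-worker index file found.
_litdata_merge(output_dir)
```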
@@ -695,7 +719,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
         if num_workers == 1:
             logger.info("Single worker mode, processing sequentially")
             _task_transform_fn((0, task, patient_ids, global_event_df, output_dir))
-            BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
+            _litdata_merge(output_dir)
             return
 
         # spawn is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary
@@ -721,7 +745,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
                 while not queue.empty():
                     progress.update(queue.get())
                 result.get()  # ensure exceptions are raised
-            BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
+            _litdata_merge(output_dir)
 
             logger.info(f"Task transformation completed and saved to {output_dir}")
         except Exception as e:
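For readers unfamiliar with the context lines in this hunk, here is a self-contained sketch of the spawn-context pool plus progress-queue pattern they follow; the worker function, counts, and prints are made up for illustration and are not taken from this repository.

```python
import multiprocessing


def _worker(args):
    """Hypothetical worker: does some work and reports progress via the queue."""
    rank, n_items, queue = args
    for _ in range(n_items):
        # ... one unit of work per item ...
        queue.put(1)  # report one completed item to the parent process
    return rank


if __name__ == "__main__":
    # "spawn" is the start method the diff uses (required for polars workers).
    ctx = multiprocessing.get_context("spawn")
    with ctx.Manager() as manager:
        queue = manager.Queue()  # Manager queues can be shared with pool workers
        with ctx.Pool(processes=4) as pool:
            result = pool.map_async(_worker, [(i, 10, queue) for i in range(4)])
            while not result.ready():
                while not queue.empty():
                    print("progress:", queue.get())
                result.wait(0.1)  # avoid busy-spinning while workers run
            result.get()  # re-raises any exception from the workers
```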
@@ -745,7 +769,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
         if num_workers == 1:
             logger.info("Single worker mode, processing sequentially")
             _proc_transform_fn((0, task_df, 0, num_samples, output_dir))
-            BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
+            _litdata_merge(output_dir)
             return
 
         ctx = multiprocessing.get_context("spawn")
@@ -771,7 +795,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
                 while not queue.empty():
                     progress.update(queue.get())
                 result.get()  # ensure exceptions are raised
-            BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
+            _litdata_merge(output_dir)
 
             logger.info(f"Processor transformation completed and saved to {output_dir}")
         except Exception as e:
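The diff does not show how the merged output is consumed downstream, but a directory containing a merged LitData index is typically readable with litdata's StreamingDataset. A minimal sketch, using a placeholder path rather than one from this repository:

```python
from litdata import StreamingDataset

# Placeholder path; in practice this is the output_dir handed to _litdata_merge
# by _task_transform / _proc_transform above.
dataset = StreamingDataset(input_dir="cache/task_output")

print(len(dataset))   # total number of samples across the merged chunks
sample = dataset[0]   # deserializes to whatever structure the workers wrote
```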