@@ -849,44 +849,74 @@ def set_task(
         cache_dir = Path(cache_dir)
         cache_dir.mkdir(parents=True, exist_ok=True)
 
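+        # Serialize the schemas and processor configurations into a canonical
+        # JSON string; this fingerprints the processing pipeline for cache keying.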
+        proc_params = json.dumps(
+            {
+                "input_schema": task.input_schema,
+                "output_schema": task.output_schema,
+                "input_processors": (
+                    {
+                        f"{k}_{v.__class__.__name__}": vars(v)
+                        for k, v in input_processors.items()
+                    }
+                    if input_processors
+                    else None
+                ),
+                "output_processors": (
+                    {
+                        f"{k}_{v.__class__.__name__}": vars(v)
+                        for k, v in output_processors.items()
+                    }
+                    if output_processors
+                    else None
+                ),
+            },
+            sort_keys=True,
+            default=str,
+        )
+
         task_df_path = Path(cache_dir) / "task_df.ld"
-        samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}.ld"
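+        # uuid5 is deterministic, so the same processor configuration always
+        # maps to the same cache directory (uuid4 was random on every call).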
+        samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld"
 
         task_df_path.mkdir(parents=True, exist_ok=True)
         samples_path.mkdir(parents=True, exist_ok=True)
+
+        if not (samples_path / "index.json").exists():
+            # Check whether index.json exists to verify cache integrity; this
+            # is the standard index file for litdata.StreamingDataset.
+            if not (task_df_path / "index.json").exists():
+                self._task_transform(
+                    task,
+                    task_df_path,
+                    num_workers,
+                )
+            else:
+                logger.info(f"Found cached task dataframe at {task_df_path}, skipping task transformation.")
 
-        # Check if index.json exists to verify cache integrity, this
-        # is the standard file for litdata.StreamingDataset
-        if not (task_df_path / "index.json").exists():
-            self._task_transform(
-                task,
+            # Build processors and fit them on the dataset
+            logger.info("Fitting processors on the dataset...")
+            dataset = litdata.StreamingDataset(
+                str(task_df_path),
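+                # each cached record holds a pickled sample under the "sample" key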
+                transform=lambda x: pickle.loads(x["sample"]),
+            )
+            builder = SampleBuilder(
+                input_schema=task.input_schema,  # type: ignore
+                output_schema=task.output_schema,  # type: ignore
+                input_processors=input_processors,
+                output_processors=output_processors,
+            )
+            builder.fit(dataset)
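+            # Persist the fitted state so it can be reloaded without refitting.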
+            builder.save(str(samples_path / "schema.pkl"))
+
+            # Apply processors and save final samples to cache_dir
+            logger.info(f"Processing samples and saving to {samples_path}...")
+            self._proc_transform(
                 task_df_path,
+                samples_path,
                 num_workers,
             )
-
-        # Build processors and fit on the dataset
-        logger.info(f"Fitting processors on the dataset...")
-        dataset = litdata.StreamingDataset(
-            str(task_df_path),
-            transform=lambda x: pickle.loads(x["sample"]),
-        )
-        builder = SampleBuilder(
-            input_schema=task.input_schema,  # type: ignore
-            output_schema=task.output_schema,  # type: ignore
-            input_processors=input_processors,
-            output_processors=output_processors,
-        )
-        builder.fit(dataset)
-        builder.save(str(samples_path / "schema.pkl"))
-
-        # Apply processors and save final samples to cache_dir
-        logger.info(f"Processing samples and saving to {samples_path}...")
-        self._proc_transform(
-            task_df_path,
-            samples_path,
-            num_workers,
-        )
-        logger.info(f"Cached processed samples to {samples_path}")
+            logger.info(f"Cached processed samples to {samples_path}")
+        else:
+            logger.info(f"Found cached processed samples at {samples_path}, skipping processing.")
 
         return SampleDataset(
             path=str(samples_path),
@@ -902,4 +932,4 @@ def _main_guard(self, func_name: str):
             f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n"
             + "Consider using a __name__ == '__main__' guard when using multiprocessing."
         )
-        exit(1)
+        exit(1)