diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index 91a1c95a..f40676d8 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -849,44 +849,74 @@ def set_task(
 
         cache_dir = Path(cache_dir)
         cache_dir.mkdir(parents=True, exist_ok=True)
 
+        proc_params = json.dumps(
+            {
+                "input_schema": task.input_schema,
+                "output_schema": task.output_schema,
+                "input_processors": (
+                    {
+                        f"{k}_{v.__class__.__name__}": vars(v)
+                        for k, v in input_processors.items()
+                    }
+                    if input_processors
+                    else None
+                ),
+                "output_processors": (
+                    {
+                        f"{k}_{v.__class__.__name__}": vars(v)
+                        for k, v in output_processors.items()
+                    }
+                    if output_processors
+                    else None
+                ),
+            },
+            sort_keys=True,
+            default=str
+        )
+
         task_df_path = Path(cache_dir) / "task_df.ld"
-        samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}.ld"
+        samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld"
         task_df_path.mkdir(parents=True, exist_ok=True)
         samples_path.mkdir(parents=True, exist_ok=True)
+
+        if not (samples_path / "index.json").exists():
+            # Check if index.json exists to verify cache integrity, this
+            # is the standard file for litdata.StreamingDataset
+            if not (task_df_path / "index.json").exists():
+                self._task_transform(
+                    task,
+                    task_df_path,
+                    num_workers,
+                )
+            else:
+                logger.info(f"Found cached task dataframe at {task_df_path}, skipping task transformation.")
 
-        # Check if index.json exists to verify cache integrity, this
-        # is the standard file for litdata.StreamingDataset
-        if not (task_df_path / "index.json").exists():
-            self._task_transform(
-                task,
+            # Build processors and fit on the dataset
+            logger.info(f"Fitting processors on the dataset...")
+            dataset = litdata.StreamingDataset(
+                str(task_df_path),
+                transform=lambda x: pickle.loads(x["sample"]),
+            )
+            builder = SampleBuilder(
+                input_schema=task.input_schema,  # type: ignore
+                output_schema=task.output_schema,  # type: ignore
+                input_processors=input_processors,
+                output_processors=output_processors,
+            )
+            builder.fit(dataset)
+            builder.save(str(samples_path / "schema.pkl"))
+
+            # Apply processors and save final samples to cache_dir
+            logger.info(f"Processing samples and saving to {samples_path}...")
+            self._proc_transform(
                 task_df_path,
+                samples_path,
                 num_workers,
             )
-
-        # Build processors and fit on the dataset
-        logger.info(f"Fitting processors on the dataset...")
-        dataset = litdata.StreamingDataset(
-            str(task_df_path),
-            transform=lambda x: pickle.loads(x["sample"]),
-        )
-        builder = SampleBuilder(
-            input_schema=task.input_schema,  # type: ignore
-            output_schema=task.output_schema,  # type: ignore
-            input_processors=input_processors,
-            output_processors=output_processors,
-        )
-        builder.fit(dataset)
-        builder.save(str(samples_path / "schema.pkl"))
-
-        # Apply processors and save final samples to cache_dir
-        logger.info(f"Processing samples and saving to {samples_path}...")
-        self._proc_transform(
-            task_df_path,
-            samples_path,
-            num_workers,
-        )
-        logger.info(f"Cached processed samples to {samples_path}")
+            logger.info(f"Cached processed samples to {samples_path}")
+        else:
+            logger.info(f"Found cached processed samples at {samples_path}, skipping processing.")
 
         return SampleDataset(
             path=str(samples_path),
@@ -902,4 +932,4 @@ def _main_guard(self, func_name: str):
                 f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n"
                 + "Consider use __name__ == '__main__' guard when using multiprocessing."
             )
-        exit(1)
\ No newline at end of file
+        exit(1)
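
A minimal sketch (not part of the patch) of why the switch from uuid.uuid4() to uuid.uuid5() matters: the samples cache directory is now named deterministically from the serialized task schemas and processor settings, so a later set_task() call with the same configuration resolves to the same samples_<uuid>.ld path and can reuse the cached index.json instead of reprocessing. The schema dictionaries and cache root below are hypothetical placeholders, not values from the patch.

```python
import json
import uuid
from pathlib import Path

# Hypothetical stand-ins for task.input_schema / task.output_schema and the
# processor settings; in the patch these come from the Task and processor objects.
proc_params = json.dumps(
    {
        "input_schema": {"conditions": "sequence"},
        "output_schema": {"mortality": "binary"},
        "input_processors": None,
        "output_processors": None,
    },
    sort_keys=True,  # stable key order -> stable string -> stable uuid5
    default=str,
)

cache_dir = Path("/tmp/pyhealth_cache")  # hypothetical cache root
samples_path = cache_dir / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld"

# The same configuration always maps to the same directory, unlike uuid4(),
# which produced a fresh directory (and a full reprocessing pass) on every run.
print(samples_path)
```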