Commit 7d734f2

Add option to cache transformed data from processors and skip pipeline entirely (#783)
* skip proc transformation if it already exists
* Fix sort_key
* Fix test
* remove hex for backward compatibility
1 parent f5d0ad5 commit 7d734f2

File tree

1 file changed: +61 -31 lines


pyhealth/datasets/base_dataset.py

Lines changed: 61 additions & 31 deletions
@@ -849,44 +849,74 @@ def set_task(
         cache_dir = Path(cache_dir)
         cache_dir.mkdir(parents=True, exist_ok=True)
 
+        proc_params = json.dumps(
+            {
+                "input_schema": task.input_schema,
+                "output_schema": task.output_schema,
+                "input_processors": (
+                    {
+                        f"{k}_{v.__class__.__name__}": vars(v)
+                        for k, v in input_processors.items()
+                    }
+                    if input_processors
+                    else None
+                ),
+                "output_processors": (
+                    {
+                        f"{k}_{v.__class__.__name__}": vars(v)
+                        for k, v in output_processors.items()
+                    }
+                    if output_processors
+                    else None
+                ),
+            },
+            sort_keys=True,
+            default=str
+        )
+
         task_df_path = Path(cache_dir) / "task_df.ld"
-        samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}.ld"
+        samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld"
 
         task_df_path.mkdir(parents=True, exist_ok=True)
         samples_path.mkdir(parents=True, exist_ok=True)
+
+        if not (samples_path / "index.json").exists():
+            # Check if index.json exists to verify cache integrity, this
+            # is the standard file for litdata.StreamingDataset
+            if not (task_df_path / "index.json").exists():
+                self._task_transform(
+                    task,
+                    task_df_path,
+                    num_workers,
+                )
+            else:
+                logger.info(f"Found cached task dataframe at {task_df_path}, skipping task transformation.")
 
-        # Check if index.json exists to verify cache integrity, this
-        # is the standard file for litdata.StreamingDataset
-        if not (task_df_path / "index.json").exists():
-            self._task_transform(
-                task,
+            # Build processors and fit on the dataset
+            logger.info(f"Fitting processors on the dataset...")
+            dataset = litdata.StreamingDataset(
+                str(task_df_path),
+                transform=lambda x: pickle.loads(x["sample"]),
+            )
+            builder = SampleBuilder(
+                input_schema=task.input_schema,  # type: ignore
+                output_schema=task.output_schema,  # type: ignore
+                input_processors=input_processors,
+                output_processors=output_processors,
+            )
+            builder.fit(dataset)
+            builder.save(str(samples_path / "schema.pkl"))
+
+            # Apply processors and save final samples to cache_dir
+            logger.info(f"Processing samples and saving to {samples_path}...")
+            self._proc_transform(
                 task_df_path,
+                samples_path,
                 num_workers,
             )
-
-        # Build processors and fit on the dataset
-        logger.info(f"Fitting processors on the dataset...")
-        dataset = litdata.StreamingDataset(
-            str(task_df_path),
-            transform=lambda x: pickle.loads(x["sample"]),
-        )
-        builder = SampleBuilder(
-            input_schema=task.input_schema,  # type: ignore
-            output_schema=task.output_schema,  # type: ignore
-            input_processors=input_processors,
-            output_processors=output_processors,
-        )
-        builder.fit(dataset)
-        builder.save(str(samples_path / "schema.pkl"))
-
-        # Apply processors and save final samples to cache_dir
-        logger.info(f"Processing samples and saving to {samples_path}...")
-        self._proc_transform(
-            task_df_path,
-            samples_path,
-            num_workers,
-        )
-        logger.info(f"Cached processed samples to {samples_path}")
+            logger.info(f"Cached processed samples to {samples_path}")
+        else:
+            logger.info(f"Found cached processed samples at {samples_path}, skipping processing.")
 
         return SampleDataset(
             path=str(samples_path),
@@ -902,4 +932,4 @@ def _main_guard(self, func_name: str):
             f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n"
             + "Consider use __name__ == '__main__' guard when using multiprocessing."
         )
-        exit(1)
+        exit(1)
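
The cache reuse above depends on the samples directory name being reproducible across runs: json.dumps with sort_keys=True canonicalizes the processor configuration, and uuid.uuid5 (a name-based UUID) maps that string to the same identifier every time. A minimal, self-contained sketch of that derivation, using a made-up processor configuration rather than PyHealth's real schema objects:

import json
import uuid

# Hypothetical stand-in for the task/processor configuration; any dict that
# json.dumps can serialize (with default=str as a fallback) behaves the same.
params = json.dumps(
    {
        "input_schema": {"conditions": "sequence", "procedures": "sequence"},
        "output_schema": {"mortality": "binary"},
        "input_processors": None,
        "output_processors": None,
    },
    sort_keys=True,  # canonical key order -> identical string for identical config
    default=str,     # stringify anything JSON cannot encode natively
)

# uuid5 is deterministic: the same namespace and name always yield the same UUID,
# so the cache directory name is stable, unlike the previous uuid4(), which was
# random on every call and therefore never produced a cache hit.
key = uuid.uuid5(uuid.NAMESPACE_DNS, params)
print(f"samples_{key}.ld")

# Re-deriving the key from an equivalent configuration reproduces the same name,
# which is what lets a later set_task call find the cached samples and skip work.
assert key == uuid.uuid5(uuid.NAMESPACE_DNS, params)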
