
Commit 82011b0

Update the default task cache path to include task parameter names and values (#766)
* Document ReadmissionPredictionMIMIC3 as a class instead of a function
* Call `close` on sample datasets
* Remove the TemporaryDirectory.cleanup() calls as cleanup will be called automatically when the current context is exited
* Update the default task cache path to include task parameter names and values
* Use UUID v5 in task cache names
* Update task cache_dir docs for UUID v5 based default
1 parent a3750b0 commit 82011b0
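
The heart of the change is how the default task cache directory is named: the task's instance attributes are serialized to a canonical JSON string and hashed into a deterministic UUID v5, so tasks with different parameter values no longer collide in one cache directory. Below is a minimal sketch of that naming scheme; the WindowedTask class and its window_days parameter are hypothetical stand-ins, and the actual logic lives in BaseDataset.set_task (see the pyhealth/datasets/base_dataset.py diff below).

import json
import uuid
from pathlib import Path


class WindowedTask:
    # Hypothetical task used only for illustration; real tasks subclass pyhealth.tasks.BaseTask.
    task_name = "windowed_task"

    def __init__(self, window_days: int = 30):
        self.window_days = window_days


def default_task_cache_dir(cache_dir: Path, task) -> Path:
    # Serialize the task's instance attributes deterministically (sorted keys, str() fallback)...
    task_params = json.dumps(vars(task), sort_keys=True, default=str)
    # ...then derive a stable UUID v5 from that string, mirroring the set_task change in this commit.
    return cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"


# Different parameter values map to different default cache directories:
print(default_task_cache_dir(Path("/tmp/pyhealth_cache"), WindowedTask(window_days=30)))
print(default_task_cache_dir(Path("/tmp/pyhealth_cache"), WindowedTask(window_days=60)))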


6 files changed: +134, -98 lines changed


docs/api/tasks/pyhealth.tasks.readmission_prediction.rst

Lines changed: 5 additions & 1 deletion
@@ -1,7 +1,11 @@
 pyhealth.tasks.readmission_prediction
 =======================================
 
-.. autofunction:: pyhealth.tasks.readmission_prediction.ReadmissionPredictionMIMIC3
+.. autoclass:: pyhealth.tasks.readmission_prediction.ReadmissionPredictionMIMIC3
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_mimic4_fn
 .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn
 .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn2

pyhealth/datasets/base_dataset.py

Lines changed: 33 additions & 27 deletions
@@ -118,7 +118,7 @@ def _csv_tsv_gz_path(path: str) -> str:
 class _ProgressContext:
     def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs):
         """
-        :param queue: An existing queue (e.g., from multiprocessing). If provided,
+        :param queue: An existing queue (e.g., from multiprocessing). If provided,
             this class acts as a passthrough.
         :param total: Total items for the progress bar (only used if queue is None).
         :param kwargs: Extra arguments for tqdm (e.g., desc="Processing").

@@ -135,7 +135,7 @@ def put(self, n):
     def __enter__(self):
         if self.queue:
             return self.queue
-
+
         self.progress = tqdm(total=self.total, **self.kwargs)
         return self
 
@@ -158,7 +158,7 @@ def _task_transform_init(queue: multiprocessing.queues.Queue) -> None:
 def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None:
     """
     Worker function to apply task transformation on a chunk of patients.
-
+
     Args:
         args (tuple): A tuple containing:
             worker_id (int): The ID of the worker.

@@ -171,13 +171,13 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None:
     worker_id, task, patient_ids, global_event_df, output_dir = args
     total_patients = len(list(patient_ids))
     logger.info(f"Worker {worker_id} started processing {total_patients} patients. (Polars threads: {pl.thread_pool_size()})")
-
+
     with (
-        set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)),
+        set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)),
         _ProgressContext(_task_transform_progress, total=total_patients) as progress
     ):
         writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB")
-
+
         write_index = 0
         batches = itertools.batched(patient_ids, BATCH_SIZE)
         for batch in batches:

@@ -210,11 +210,11 @@ def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None:
     """
     global _proc_transform_progress
     _proc_transform_progress = queue
-
+
 def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
     """
     Worker function to apply processors on a chunk of samples.
-
+
     Args:
         args (tuple): A tuple containing:
             worker_id (int): The ID of the worker.

@@ -233,15 +233,15 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
         _ProgressContext(_proc_transform_progress, total=total_samples) as progress
     ):
         writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB")
-
+
         dataset = litdata.StreamingDataset(str(task_df))
         complete = 0
         with open(f"{output_dir}/schema.pkl", "rb") as f:
             metadata = pickle.load(f)
 
         input_processors = metadata["input_processors"]
         output_processors = metadata["output_processors"]
-
+
         write_index = 0
         for i in range(start_idx, end_idx):
             transformed: Dict[str, Any] = {}

@@ -255,7 +255,7 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
             writer.add_item(write_index, transformed)
             write_index += 1
             complete += 1
-
+
             if complete >= BATCH_SIZE:
                 progress.put(complete)
                 complete = 0

@@ -680,7 +680,7 @@ def default_task(self) -> Optional[BaseTask]:
 
     def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None:
         self._main_guard(self._task_transform.__name__)
-
+
         logger.info(f"Applying task transformations on data with {num_workers} workers...")
         global_event_df = task.pre_filter(self.global_event_df)
         patient_ids = (

@@ -691,16 +691,16 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None:
             # .sort can reduce runtime by 5%.
             .sort()
         )
-
+
         if in_notebook():
             logger.info("Detected Jupyter notebook environment, setting num_workers to 1")
             num_workers = 1
         num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers
-
+
         # This ensures worker's polars threads are limited to avoid oversubscription,
         # which can lead to additional 75% speedup when num_workers is large.
         threads_per_worker = max(1, (os.cpu_count() or 1) // num_workers)
-
+
         try:
             with set_env(POLARS_MAX_THREADS=str(threads_per_worker), DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)):
                 if num_workers == 1:

@@ -727,7 +727,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None:
                         progress.update(queue.get(timeout=1))
                     except:
                         pass
-
+
                 # remaining items
                 while not queue.empty():
                     progress.update(queue.get())

@@ -739,17 +739,17 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None:
             logger.error(f"Error during task transformation, cleaning up output directory: {output_dir}")
             shutil.rmtree(output_dir)
             raise e
-
+
     def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> None:
         self._main_guard(self._proc_transform.__name__)
-
+
         logger.info(f"Applying processors on data with {num_workers} workers...")
         num_samples = len(litdata.StreamingDataset(str(task_df)))
-
+
         if in_notebook():
             logger.info("Detected Jupyter notebook environment, setting num_workers to 1")
             num_workers = 1
-
+
         num_workers = min(num_workers, num_samples) # Avoid spawning empty workers
         try:
             with set_env(DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)):

@@ -758,7 +758,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> None:
                 _proc_transform_fn((0, task_df, 0, num_samples, output_dir))
                 BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
                 return
-
+
             ctx = multiprocessing.get_context("spawn")
             queue = ctx.Queue()
             linspace = more_itertools.sliding_window(np.linspace(0, num_samples, num_workers + 1, dtype=int), 2)

@@ -777,7 +777,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> None:
                     progress.update(queue.get(timeout=1))
                 except:
                     pass
-
+
             # remaining items
             while not queue.empty():
                 progress.update(queue.get())

@@ -814,8 +814,8 @@ def set_task(
         Args:
            task (Optional[BaseTask]): The task to set. Uses default task if None.
            num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`.
-            cache_dir (Optional[str]): Directory to cache samples after task transformation,
-                but without applying processors. Default is {self.cache_dir}/tasks/{task_name}.
+            cache_dir (Optional[str]): Directory to cache samples after task transformation,
+                but without applying processors. Default is {self.cache_dir}/tasks/{task_name}_{uuid5(vars(task))}.
            cache_format (str): Deprecated. Only "parquet" is supported now.
            input_processors (Optional[Dict[str, FeatureProcessor]]):
                Pre-fitted input processors. If provided, these will be used

@@ -835,7 +835,7 @@ def set_task(
         if task is None:
             assert self.default_task is not None, "No default tasks found"
             task = self.default_task
-
+
         if num_workers is None:
             num_workers = self.num_workers
 
@@ -846,8 +846,14 @@ def set_task(
             f"Setting task {task.task_name} for {self.dataset_name} base dataset..."
         )
 
+        task_params = json.dumps(
+            vars(task),
+            sort_keys=True,
+            default=str
+        )
+
         if cache_dir is None:
-            cache_dir = self.cache_dir / "tasks" / task.task_name
+            cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
             cache_dir.mkdir(parents=True, exist_ok=True)
         else:
             # Ensure the explicitly provided cache_dir exists

@@ -856,7 +862,7 @@ def set_task(
 
         task_df_path = Path(cache_dir) / "task_df.ld"
         samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}.ld"
-
+
         task_df_path.mkdir(parents=True, exist_ok=True)
         samples_path.mkdir(parents=True, exist_ok=True)

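Because UUID v5 is a deterministic function of the namespace and the serialized parameters, a later set_task call with an identically parameterized task resolves to the same directory and reuses the cached samples, while any change in parameter values produces a new directory. A small standalone check of that property, using only the standard library and a hypothetical window_days parameter:

import json
import uuid

params_a = json.dumps({"window_days": 30}, sort_keys=True, default=str)
params_b = json.dumps({"window_days": 30}, sort_keys=True, default=str)
params_c = json.dumps({"window_days": 60}, sort_keys=True, default=str)

# Identical parameters -> identical UUID -> the existing cache directory is reused.
assert uuid.uuid5(uuid.NAMESPACE_DNS, params_a) == uuid.uuid5(uuid.NAMESPACE_DNS, params_b)
# Different parameters -> different UUID -> a separate cache directory is created.
assert uuid.uuid5(uuid.NAMESPACE_DNS, params_a) != uuid.uuid5(uuid.NAMESPACE_DNS, params_c)
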
pyhealth/datasets/chestxray14.py

Lines changed: 3 additions & 1 deletion
@@ -53,7 +53,8 @@ def __init__(self,
                  root: str = ".",
                  config_path: Optional[str] = str(Path(__file__).parent / "configs" / "chestxray14.yaml"),
                  download: bool = False,
-                 partial: bool = False) -> None:
+                 partial: bool = False,
+                 **kwargs) -> None:
         """Initializes the ChestX-ray14 dataset.
 
         Args:

@@ -87,6 +88,7 @@ def __init__(self,
             tables=["chestxray14"],
             dataset_name="ChestX-ray14",
             config_path=config_path,
+            **kwargs
         )
 
     @property

tests/core/test_caching.py

Lines changed: 63 additions & 23 deletions
@@ -3,9 +3,10 @@
 import shutil
 from pathlib import Path
 from unittest.mock import patch
-import polars as pl
 import dask.dataframe as dd
 import torch
+import json
+import uuid
 
 from tests.base import BaseTestCase
 from pyhealth.datasets.base_dataset import BaseDataset

@@ -15,12 +16,14 @@
 
 class MockTask(BaseTask):
     """Mock task for testing purposes."""
+    task_name = "test_task"
+    input_schema = {"test_attribute": "raw"}
+    output_schema = {"test_label": "binary"}
 
-    def __init__(self, task_name="test_task"):
-        self.task_name = task_name
-        self.input_schema = {"test_attribute": "raw"}
-        self.output_schema = {"test_label": "binary"}
+    def __init__(self, param=None):
         self.call_count = 0
+        if param:
+            self.param = param
 
     def __call__(self, patient):
         """Return mock samples based on patient data."""

@@ -77,20 +80,18 @@ def load_data(self) -> dd.DataFrame:
 class TestCachingFunctionality(BaseTestCase):
     """Test cases for caching functionality in BaseDataset.set_task()."""
 
+    @classmethod
+    def setUpClass(cls):
+        cls.temp_dir = tempfile.TemporaryDirectory()
+        cls.dataset = MockDataset(cache_dir=cls.temp_dir.name)
+
     def setUp(self):
-        """Set up test fixtures."""
-        self.temp_dir = Path(tempfile.mkdtemp())
-        self.dataset = MockDataset(cache_dir=self.temp_dir)
         self.task = MockTask()
+        self.cache_dir = Path(self.temp_dir.name) / "task_cache"
+        self.cache_dir.mkdir()
 
     def tearDown(self):
-        """Clean up test fixtures."""
-        shutil.rmtree(self.temp_dir, ignore_errors=True)
-
-    def _task_cache_dir(self) -> Path:
-        cache_dir = self.temp_dir / "task_cache"
-        cache_dir.mkdir(parents=True, exist_ok=True)
-        return cache_dir
+        shutil.rmtree(self.cache_dir)
 
     def test_set_task_signature(self):
         """Test that set_task has the correct method signature."""

@@ -120,9 +121,8 @@ def test_set_task_signature(self):
 
     def test_set_task_writes_cache_and_metadata(self):
         """Ensure set_task materializes cache files and schema metadata."""
-        cache_dir = self._task_cache_dir()
         with self.dataset.set_task(
-            self.task, cache_dir=cache_dir, cache_format="parquet"
+            self.task, cache_dir=self.cache_dir, cache_format="parquet"
         ) as sample_dataset:
             self.assertIsInstance(sample_dataset, SampleDataset)
             self.assertEqual(sample_dataset.dataset_name, "TestDataset")

@@ -131,7 +131,7 @@ def test_set_task_writes_cache_and_metadata(self):
             self.assertEqual(self.task.call_count, 2)
 
             # Ensure intermediate cache files are created
-            self.assertTrue((cache_dir / "task_df.ld" / "index.json").exists())
+            self.assertTrue((self.cache_dir / "task_df.ld" / "index.json").exists())
 
             # Cache artifacts should be present for StreamingDataset
             assert sample_dataset.input_dir.path is not None
@@ -156,35 +156,75 @@ def test_set_task_writes_cache_and_metadata(self):
         self.assertFalse((sample_dir / "index.json").exists())
         self.assertFalse((sample_dir / "schema.pkl").exists())
         # Ensure intermediate cache files are still present
-        self.assertTrue((cache_dir / "task_df.ld" / "index.json").exists())
+        self.assertTrue((self.cache_dir / "task_df.ld" / "index.json").exists())
 
 
     def test_default_cache_dir_is_used(self):
         """When cache_dir is omitted, default cache dir should be used."""
-        task_cache = self.dataset.cache_dir / "tasks" / self.task.task_name
+        task_params = json.dumps(
+            {"call_count": 0},
+            sort_keys=True,
+            default=str
+        )
+
+        task_cache = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
         sample_dataset = self.dataset.set_task(self.task)
 
         self.assertTrue(task_cache.exists())
         self.assertTrue((task_cache / "task_df.ld" / "index.json").exists())
         self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists())
         self.assertEqual(len(sample_dataset), 4)
 
+        sample_dataset.close()
+
     def test_reuses_existing_cache_without_regeneration(self):
         """Second call should reuse cached samples instead of recomputing."""
-        cache_dir = self._task_cache_dir()
-        _ = self.dataset.set_task(self.task, cache_dir=cache_dir)
+        sample_dataset = self.dataset.set_task(self.task, cache_dir=self.cache_dir)
         self.assertEqual(self.task.call_count, 2)
 
         with patch.object(
             self.task, "__call__", side_effect=AssertionError("Task should not rerun")
         ):
             cached_dataset = self.dataset.set_task(
-                self.task, cache_dir=cache_dir, cache_format="parquet"
+                self.task, cache_dir=self.cache_dir, cache_format="parquet"
             )
 
         self.assertEqual(len(cached_dataset), 4)
         self.assertEqual(self.task.call_count, 2)
 
+        sample_dataset.close()
+        cached_dataset.close()
+
+    def test_tasks_with_diff_param_values_get_diff_caches(self):
+        sample_dataset1 = self.dataset.set_task(MockTask(param=1))
+        sample_dataset2 = self.dataset.set_task(MockTask(param=2))
+
+        task_params1 = json.dumps(
+            {"call_count": 0, "param": 1},
+            sort_keys=True,
+            default=str
+        )
+
+        task_params2 = json.dumps(
+            {"call_count": 0, "param": 2},
+            sort_keys=True,
+            default=str
+        )
+
+        task_cache1 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params1)}"
+        task_cache2 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params2)}"
+
+        self.assertTrue(task_cache1.exists())
+        self.assertTrue(task_cache2.exists())
+        self.assertTrue((task_cache1 / "task_df.ld" / "index.json").exists())
+        self.assertTrue((task_cache2 / "task_df.ld" / "index.json").exists())
+        self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists())
+        self.assertEqual(len(sample_dataset1), 4)
+        self.assertEqual(len(sample_dataset2), 4)
+
+        sample_dataset1.close()
+        sample_dataset2.close()
+
 
 if __name__ == "__main__":
     unittest.main()
