Skip to content

Commit 6cfa99e

Browse files
authored
Add IgnoreProcessor to support remove a key in schema. (#776)
* ensure code reusibility * Fix incorrect unpickle * add ignore processor * Do not process ignore processor * Add test, fix bugs * Update docstring
1 parent c63aae4 commit 6cfa99e

File tree

7 files changed

+209
-26
lines changed

7 files changed

+209
-26
lines changed

docs/api/processors.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ Available Processors
4444
- ``StageNetProcessor``: For StageNet model with lab measurements
4545
- ``StageNetTensorProcessor``: Tensor processing for StageNet
4646
- ``MultiHotProcessor``: For multi-hot encoding
47+
- ``IgnoreProcessor``: A special feature processor that marks a feature to be ignored.
4748

4849
Usage Examples
4950
--------------
@@ -460,6 +461,7 @@ API Reference
460461
processors/pyhealth.processors.TimeseriesProcessor
461462
processors/pyhealth.processors.TensorProcessor
462463
processors/pyhealth.processors.RawProcessor
464+
processors/pyhealth.processors.IgnoreProcessor
463465
processors/pyhealth.processors.MultiHotProcessor
464466
processors/pyhealth.processors.StageNetProcessor
465467
processors/pyhealth.processors.StageNetTensorProcessor
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
pyhealth.processors.IgnoreProcessor
2+
======================================
3+
4+
Processor to ignore a feature.
5+
6+
.. autoclass:: pyhealth.processors.IgnoreProcessor
7+
:members:
8+
:undoc-members:
9+
:show-inheritance:

pyhealth/datasets/base_dataset.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -235,30 +235,19 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
235235
writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB")
236236

237237
dataset = litdata.StreamingDataset(str(task_df))
238-
complete = 0
239-
with open(f"{output_dir}/schema.pkl", "rb") as f:
240-
metadata = pickle.load(f)
241-
242-
input_processors = metadata["input_processors"]
243-
output_processors = metadata["output_processors"]
244-
245-
write_index = 0
246-
for i in range(start_idx, end_idx):
247-
transformed: Dict[str, Any] = {}
248-
for key, value in pickle.loads(dataset[i]["sample"]).items():
249-
if key in input_processors:
250-
transformed[key] = input_processors[key].process(value)
251-
elif key in output_processors:
252-
transformed[key] = output_processors[key].process(value)
253-
else:
254-
transformed[key] = value
255-
writer.add_item(write_index, transformed)
256-
write_index += 1
257-
complete += 1
238+
builder = SampleBuilder.load(f"{output_dir}/schema.pkl")
258239

259-
if complete >= BATCH_SIZE:
260-
progress.put(complete)
261-
complete = 0
240+
complete = 0
241+
write_index = 0
242+
for i in range(start_idx, end_idx):
243+
transformed: Dict[str, Any] = builder.transform(dataset[i])
244+
writer.add_item(write_index, transformed)
245+
write_index += 1
246+
complete += 1
247+
248+
if complete >= BATCH_SIZE:
249+
progress.put(complete)
250+
complete = 0
262251

263252
if complete > 0:
264253
progress.put(complete)

pyhealth/datasets/sample_dataset.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from litdata.utilities.train_test_split import deepcopy_dataset
1212
import copy
1313

14-
from ..processors import get_processor
14+
from ..processors import get_processor, IgnoreProcessor
1515
from ..processors.base_processor import FeatureProcessor
1616

1717

@@ -191,8 +191,14 @@ def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]:
191191
transformed: Dict[str, Any] = {}
192192
for key, value in pickle.loads(sample["sample"]).items():
193193
if key in self._input_processors:
194+
# Skip ignored features
195+
if isinstance(self._input_processors[key], IgnoreProcessor):
196+
continue
194197
transformed[key] = self._input_processors[key].process(value)
195198
elif key in self._output_processors:
199+
# Skip ignored features
200+
if isinstance(self._output_processors[key], IgnoreProcessor):
201+
continue
196202
transformed[key] = self._output_processors[key].process(value)
197203
else:
198204
transformed[key] = value
@@ -221,6 +227,30 @@ def save(self, path: str) -> None:
221227
with open(path, "wb") as f:
222228
pickle.dump(metadata, f)
223229

230+
@staticmethod
231+
def load(path: str) -> "SampleBuilder":
232+
"""Load a SampleBuilder from a pickled metadata file.
233+
234+
Args:
235+
path: Location of the pickled metadata file (commonly named `schema.pkl`).
236+
237+
Returns:
238+
A SampleBuilder instance with loaded metadata.
239+
"""
240+
with open(path, "rb") as f:
241+
metadata = pickle.load(f)
242+
243+
builder = SampleBuilder(
244+
input_schema=metadata["input_schema"],
245+
output_schema=metadata["output_schema"],
246+
)
247+
builder._input_processors = metadata["input_processors"]
248+
builder._output_processors = metadata["output_processors"]
249+
builder._patient_to_index = metadata["patient_to_index"]
250+
builder._record_to_index = metadata["record_to_index"]
251+
builder._fitted = True
252+
return builder
253+
224254

225255
class SampleDataset(litdata.StreamingDataset):
226256
"""A streaming dataset that loads sample metadata and processors from disk.
@@ -276,10 +306,29 @@ def __init__(
276306
self.output_schema = metadata["output_schema"]
277307
self.input_processors = metadata["input_processors"]
278308
self.output_processors = metadata["output_processors"]
309+
self._remove_ignored_processors()
279310

280311
self.patient_to_index = metadata["patient_to_index"]
281312
self.record_to_index = metadata["record_to_index"]
282313

314+
def _remove_ignored_processors(self):
315+
"""Remove any processors that are IgnoreProcessor instances."""
316+
for key in [
317+
key
318+
for key, proc in self.input_processors.items()
319+
if isinstance(proc, IgnoreProcessor)
320+
]:
321+
del self.input_processors[key]
322+
del self.input_schema[key]
323+
324+
for key in [
325+
key
326+
for key, proc in self.output_processors.items()
327+
if isinstance(proc, IgnoreProcessor)
328+
]:
329+
del self.output_processors[key]
330+
del self.output_schema[key]
331+
283332
def __str__(self) -> str:
284333
"""Returns a string representation of the dataset.
285334
@@ -356,12 +405,12 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset":
356405
new_dataset.reset()
357406

358407
return new_dataset
359-
408+
360409
def close(self) -> None:
361410
"""Cleans up any temporary directories used by the dataset."""
362411
if self.input_dir.path is not None and Path(self.input_dir.path).exists():
363412
shutil.rmtree(self.input_dir.path)
364-
413+
365414
# --------------------------------------------------------------
366415
# Context manager support
367416
# --------------------------------------------------------------
@@ -426,6 +475,7 @@ def __init__(
426475
self.output_schema = builder.output_schema
427476
self.input_processors = builder.input_processors
428477
self.output_processors = builder.output_processors
478+
self._remove_ignored_processors()
429479

430480
self.patient_to_index = builder.patient_to_index
431481
self.record_to_index = builder.record_to_index
@@ -482,6 +532,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset:
482532
def close(self) -> None:
483533
pass # No temporary directories to clean up for in-memory dataset
484534

535+
485536
def create_sample_dataset(
486537
samples: List[Dict[str, Any]],
487538
input_schema: Dict[str, Any],

pyhealth/processors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def get_processor(name: str):
4545
from .text_processor import TextProcessor
4646
from .timeseries_processor import TimeseriesProcessor
4747
from .audio_processor import AudioProcessor
48+
from .ignore_processor import IgnoreProcessor
4849

4950
# Expose public API
5051
__all__ = [
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Any, Dict, Iterable
2+
from . import register_processor
3+
from .base_processor import FeatureProcessor
4+
5+
6+
@register_processor("ignore")
7+
class IgnoreProcessor(FeatureProcessor):
8+
"""A special feature processor that marks a feature to be ignored during processing.
9+
10+
This processor is useful when you want to remove a specific feature from the dataset
11+
after the task function processing, but without modifying the task function itself.
12+
13+
Example:
14+
>>> from pyhealth.processors import IgnoreProcessor
15+
>>> # Assume we have a task that outputs "feature1" and "feature2"
16+
>>> # We want to remove "feature2" from the final dataset
17+
>>> dataset.set_task(task, input_processors={
18+
... "feature1": SequenceProcessor(code_to_index),
19+
... "feature2": IgnoreProcessor()
20+
... })
21+
>>> # Now samples in dataset will only contain "feature1"
22+
"""
23+
24+
def __init__(self) -> None:
25+
pass
26+
27+
def process(self, value: Any) -> Any:
28+
"""This method is intentionally not implemented.
29+
30+
Args:
31+
value: Any raw field value.
32+
33+
Raises:
34+
NotImplementedError: Always raised to indicate this processor ignores the field.
35+
"""
36+
raise NotImplementedError("IgnoreProcessor does not implement process method.")
37+
38+
def __repr__(self) -> str:
39+
return (f"IgnoreProcessor()")
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import unittest
2+
import shutil
3+
import tempfile
4+
from pathlib import Path
5+
import pandas as pd
6+
import dask.dataframe as dd
7+
8+
from pyhealth.datasets.base_dataset import BaseDataset
9+
from pyhealth.tasks.base_task import BaseTask
10+
from pyhealth.processors.ignore_processor import IgnoreProcessor
11+
from pyhealth.processors import RawProcessor
12+
13+
class MockTask(BaseTask):
14+
task_name = "test_task"
15+
input_schema = {
16+
"keep_field": "raw",
17+
"ignore_field": "raw"
18+
}
19+
output_schema = {"label": "binary"}
20+
21+
def __call__(self, patient):
22+
return [{
23+
"keep_field": "keep_val",
24+
"ignore_field": "ignore_val",
25+
"label": 0 if patient.patient_id == "1" else 1,
26+
"patient_id": patient.patient_id
27+
}]
28+
29+
class MockDataset(BaseDataset):
30+
def __init__(self, root, **kwargs):
31+
super().__init__(root=root, tables=[], **kwargs)
32+
33+
def load_data(self):
34+
return dd.from_pandas(
35+
pd.DataFrame({
36+
"patient_id": ["1", "2"],
37+
"event_type": ["visit", "visit"],
38+
"timestamp": [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")],
39+
}),
40+
npartitions=1
41+
)
42+
43+
class TestIgnoreProcessor(unittest.TestCase):
44+
def setUp(self):
45+
self.tmp_dir = tempfile.mkdtemp()
46+
self.root = self.tmp_dir
47+
self.dataset = MockDataset(root=self.root)
48+
49+
def tearDown(self):
50+
shutil.rmtree(self.tmp_dir)
51+
52+
def test_ignore_processor_with_set_task(self):
53+
task = MockTask()
54+
55+
# 1. Normal set_task
56+
ds1 = self.dataset.set_task(task)
57+
self.assertIn("ignore_field", ds1.input_schema)
58+
59+
# Check data
60+
# We need to access the first sample.
61+
# Since SampleDataset is a StreamingDataset, we can index it or iterate.
62+
sample1 = ds1[0]
63+
self.assertIn("ignore_field", sample1)
64+
self.assertEqual(sample1["ignore_field"], "ignore_val")
65+
66+
# 2. set_task with ignore processor
67+
# We MUST provide processors for ALL fields to avoid re-population logic in SampleBuilder
68+
ds2 = self.dataset.set_task(
69+
task,
70+
input_processors={
71+
"keep_field": RawProcessor(),
72+
"ignore_field": IgnoreProcessor()
73+
}
74+
)
75+
76+
# Expectation: "ignore_field" should be removed from input_schema of the dataset
77+
# This is what the user asked for: "result should be the input_schema & input_processors does not exists"
78+
79+
# Note: Depending on current implementation, this might fail.
80+
self.assertNotIn("ignore_field", ds2.input_schema)
81+
self.assertNotIn("ignore_field", ds2.input_processors)
82+
83+
sample2 = ds2[0]
84+
# Expectation: "ignore_field" should NOT be in the sample data
85+
self.assertNotIn("ignore_field", sample2)
86+
87+
# 'keep_field' should still be there
88+
self.assertIn("keep_field", sample2)
89+
self.assertEqual(sample2["keep_field"], "keep_val")
90+
91+
if __name__ == "__main__":
92+
unittest.main()

0 commit comments

Comments
 (0)