diff --git a/docs/api/processors.rst b/docs/api/processors.rst
index dcdfb2861..25de2fece 100644
--- a/docs/api/processors.rst
+++ b/docs/api/processors.rst
@@ -44,6 +44,7 @@ Available Processors
 - ``StageNetProcessor``: For StageNet model with lab measurements
 - ``StageNetTensorProcessor``: Tensor processing for StageNet
 - ``MultiHotProcessor``: For multi-hot encoding
+- ``IgnoreProcessor``: For marking a feature to be ignored (dropped from processed samples)
 
 Usage Examples
 --------------
@@ -460,6 +461,7 @@ API Reference
    processors/pyhealth.processors.TimeseriesProcessor
    processors/pyhealth.processors.TensorProcessor
    processors/pyhealth.processors.RawProcessor
+   processors/pyhealth.processors.IgnoreProcessor
    processors/pyhealth.processors.MultiHotProcessor
    processors/pyhealth.processors.StageNetProcessor
    processors/pyhealth.processors.StageNetTensorProcessor
\ No newline at end of file
diff --git a/docs/api/processors/pyhealth.processors.IgnoreProcessor.rst b/docs/api/processors/pyhealth.processors.IgnoreProcessor.rst
new file mode 100644
index 000000000..aae6c3228
--- /dev/null
+++ b/docs/api/processors/pyhealth.processors.IgnoreProcessor.rst
@@ -0,0 +1,9 @@
+pyhealth.processors.IgnoreProcessor
+===================================
+
+Processor that marks a feature to be ignored.
+
+.. autoclass:: pyhealth.processors.IgnoreProcessor
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index ec721e8c6..91a1c95af 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -235,30 +235,19 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
     writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB")
     dataset = litdata.StreamingDataset(str(task_df))
 
-    complete = 0
-    with open(f"{output_dir}/schema.pkl", "rb") as f:
-        metadata = pickle.load(f)
-
-    input_processors = metadata["input_processors"]
-    output_processors = metadata["output_processors"]
-
-    write_index = 0
-    for i in range(start_idx, end_idx):
-        transformed: Dict[str, Any] = {}
-        for key, value in pickle.loads(dataset[i]["sample"]).items():
-            if key in input_processors:
-                transformed[key] = input_processors[key].process(value)
-            elif key in output_processors:
-                transformed[key] = output_processors[key].process(value)
-            else:
-                transformed[key] = value
-        writer.add_item(write_index, transformed)
-        write_index += 1
-        complete += 1
+    builder = SampleBuilder.load(f"{output_dir}/schema.pkl")
 
-        if complete >= BATCH_SIZE:
-            progress.put(complete)
-            complete = 0
+    complete = 0
+    write_index = 0
+    for i in range(start_idx, end_idx):
+        transformed: Dict[str, Any] = builder.transform(dataset[i])
+        writer.add_item(write_index, transformed)
+        write_index += 1
+        complete += 1
+
+        if complete >= BATCH_SIZE:
+            progress.put(complete)
+            complete = 0
 
     if complete > 0:
         progress.put(complete)
diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py
index 906b06b8b..30f4dd12e 100644
--- a/pyhealth/datasets/sample_dataset.py
+++ b/pyhealth/datasets/sample_dataset.py
@@ -11,7 +11,7 @@ from litdata.utilities.train_test_split import deepcopy_dataset
 import copy
 
-from ..processors import get_processor
+from ..processors import get_processor, IgnoreProcessor
 from ..processors.base_processor import FeatureProcessor
 
 
@@ -191,8 +191,14 @@ def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]:
         transformed: Dict[str, Any] = {}
         for key, value in pickle.loads(sample["sample"]).items():
             if key in self._input_processors:
+                # Skip ignored features
+                if isinstance(self._input_processors[key], IgnoreProcessor):
+                    continue
                 transformed[key] = self._input_processors[key].process(value)
             elif key in self._output_processors:
+                # Skip ignored features
+                if isinstance(self._output_processors[key], IgnoreProcessor):
+                    continue
                 transformed[key] = self._output_processors[key].process(value)
             else:
                 transformed[key] = value
@@ -221,6 +227,30 @@ def save(self, path: str) -> None:
         with open(path, "wb") as f:
             pickle.dump(metadata, f)
 
+    @staticmethod
+    def load(path: str) -> "SampleBuilder":
+        """Load a SampleBuilder from a pickled metadata file.
+
+        Args:
+            path: Location of the pickled metadata file (commonly named ``schema.pkl``).
+
+        Returns:
+            A fitted SampleBuilder instance with the loaded metadata.
+        """
+        with open(path, "rb") as f:
+            metadata = pickle.load(f)
+
+        builder = SampleBuilder(
+            input_schema=metadata["input_schema"],
+            output_schema=metadata["output_schema"],
+        )
+        builder._input_processors = metadata["input_processors"]
+        builder._output_processors = metadata["output_processors"]
+        builder._patient_to_index = metadata["patient_to_index"]
+        builder._record_to_index = metadata["record_to_index"]
+        builder._fitted = True
+        return builder
+
 
 class SampleDataset(litdata.StreamingDataset):
     """A streaming dataset that loads sample metadata and processors from disk.
@@ -276,10 +306,29 @@ def __init__(
         self.output_schema = metadata["output_schema"]
         self.input_processors = metadata["input_processors"]
         self.output_processors = metadata["output_processors"]
+        self._remove_ignored_processors()
         self.patient_to_index = metadata["patient_to_index"]
         self.record_to_index = metadata["record_to_index"]
 
+    def _remove_ignored_processors(self):
+        """Remove IgnoreProcessor entries and their corresponding schema keys."""
+        for key in [
+            key
+            for key, proc in self.input_processors.items()
+            if isinstance(proc, IgnoreProcessor)
+        ]:
+            del self.input_processors[key]
+            del self.input_schema[key]
+
+        for key in [
+            key
+            for key, proc in self.output_processors.items()
+            if isinstance(proc, IgnoreProcessor)
+        ]:
+            del self.output_processors[key]
+            del self.output_schema[key]
+
     def __str__(self) -> str:
         """Returns a string representation of the dataset.
@@ -356,12 +405,12 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset":
         new_dataset.reset()
         return new_dataset
-    
+
     def close(self) -> None:
         """Cleans up any temporary directories used by the dataset."""
         if self.input_dir.path is not None and Path(self.input_dir.path).exists():
             shutil.rmtree(self.input_dir.path)
-    
+
     # --------------------------------------------------------------
     # Context manager support
     # --------------------------------------------------------------
@@ -426,6 +475,7 @@ def __init__(
         self.output_schema = builder.output_schema
         self.input_processors = builder.input_processors
         self.output_processors = builder.output_processors
+        self._remove_ignored_processors()
         self.patient_to_index = builder.patient_to_index
         self.record_to_index = builder.record_to_index
@@ -482,6 +532,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset:
     def close(self) -> None:
         pass  # No temporary directories to clean up for in-memory dataset
 
+
 def create_sample_dataset(
     samples: List[Dict[str, Any]],
     input_schema: Dict[str, Any],
diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py
index 1b4e15cac..15512c2d7 100644
--- a/pyhealth/processors/__init__.py
+++ b/pyhealth/processors/__init__.py
@@ -45,6 +45,7 @@ def get_processor(name: str):
 from .text_processor import TextProcessor
 from .timeseries_processor import TimeseriesProcessor
 from .audio_processor import AudioProcessor
+from .ignore_processor import IgnoreProcessor
 
 # Expose public API
 __all__ = [
diff --git a/pyhealth/processors/ignore_processor.py b/pyhealth/processors/ignore_processor.py
new file mode 100644
index 000000000..2f8c35c31
--- /dev/null
+++ b/pyhealth/processors/ignore_processor.py
@@ -0,0 +1,39 @@
+from typing import Any
+from . import register_processor
+from .base_processor import FeatureProcessor
+
+
+@register_processor("ignore")
+class IgnoreProcessor(FeatureProcessor):
+    """A special feature processor that marks a feature to be ignored during processing.
+
+    This processor is useful for dropping a specific feature emitted by the task
+    function from the final dataset without modifying the task function itself.
+
+    Example:
+        >>> from pyhealth.processors import IgnoreProcessor
+        >>> # Assume we have a task that outputs "feature1" and "feature2"
+        >>> # and we want to remove "feature2" from the final dataset
+        >>> dataset.set_task(task, input_processors={
+        ...     "feature1": SequenceProcessor(code_to_index),
+        ...     "feature2": IgnoreProcessor()
+        ... })
+        >>> # Samples in the dataset will now only contain "feature1"
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def process(self, value: Any) -> Any:
+        """This method is intentionally not implemented.
+
+        Args:
+            value: Any raw field value.
+
+        Raises:
+            NotImplementedError: Always raised, since ignored features are skipped before processing.
+ """ + raise NotImplementedError("IgnoreProcessor does not implement process method.") + + def __repr__(self) -> str: + return (f"IgnoreProcessor()") diff --git a/tests/core/test_ignore_processor.py b/tests/core/test_ignore_processor.py new file mode 100644 index 000000000..352f35595 --- /dev/null +++ b/tests/core/test_ignore_processor.py @@ -0,0 +1,92 @@ +import unittest +import shutil +import tempfile +from pathlib import Path +import pandas as pd +import dask.dataframe as dd + +from pyhealth.datasets.base_dataset import BaseDataset +from pyhealth.tasks.base_task import BaseTask +from pyhealth.processors.ignore_processor import IgnoreProcessor +from pyhealth.processors import RawProcessor + +class MockTask(BaseTask): + task_name = "test_task" + input_schema = { + "keep_field": "raw", + "ignore_field": "raw" + } + output_schema = {"label": "binary"} + + def __call__(self, patient): + return [{ + "keep_field": "keep_val", + "ignore_field": "ignore_val", + "label": 0 if patient.patient_id == "1" else 1, + "patient_id": patient.patient_id + }] + +class MockDataset(BaseDataset): + def __init__(self, root, **kwargs): + super().__init__(root=root, tables=[], **kwargs) + + def load_data(self): + return dd.from_pandas( + pd.DataFrame({ + "patient_id": ["1", "2"], + "event_type": ["visit", "visit"], + "timestamp": [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")], + }), + npartitions=1 + ) + +class TestIgnoreProcessor(unittest.TestCase): + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + self.root = self.tmp_dir + self.dataset = MockDataset(root=self.root) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + + def test_ignore_processor_with_set_task(self): + task = MockTask() + + # 1. Normal set_task + ds1 = self.dataset.set_task(task) + self.assertIn("ignore_field", ds1.input_schema) + + # Check data + # We need to access the first sample. + # Since SampleDataset is a StreamingDataset, we can index it or iterate. + sample1 = ds1[0] + self.assertIn("ignore_field", sample1) + self.assertEqual(sample1["ignore_field"], "ignore_val") + + # 2. set_task with ignore processor + # We MUST provide processors for ALL fields to avoid re-population logic in SampleBuilder + ds2 = self.dataset.set_task( + task, + input_processors={ + "keep_field": RawProcessor(), + "ignore_field": IgnoreProcessor() + } + ) + + # Expectation: "ignore_field" should be removed from input_schema of the dataset + # This is what the user asked for: "result should be the input_schema & input_processors does not exists" + + # Note: Depending on current implementation, this might fail. + self.assertNotIn("ignore_field", ds2.input_schema) + self.assertNotIn("ignore_field", ds2.input_processors) + + sample2 = ds2[0] + # Expectation: "ignore_field" should NOT be in the sample data + self.assertNotIn("ignore_field", sample2) + + # 'keep_field' should still be there + self.assertIn("keep_field", sample2) + self.assertEqual(sample2["keep_field"], "keep_val") + +if __name__ == "__main__": + unittest.main()