35 changes: 12 additions & 23 deletions pyhealth/datasets/base_dataset.py
@@ -235,30 +235,19 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
     writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB")
 
     dataset = litdata.StreamingDataset(str(task_df))
-    complete = 0
-    with open(f"{output_dir}/schema.pkl", "rb") as f:
-        metadata = pickle.load(f)
-
-    input_processors = metadata["input_processors"]
-    output_processors = metadata["output_processors"]
-
-    write_index = 0
-    for i in range(start_idx, end_idx):
-        transformed: Dict[str, Any] = {}
-        for key, value in pickle.loads(dataset[i]["sample"]).items():
-            if key in input_processors:
-                transformed[key] = input_processors[key].process(value)
-            elif key in output_processors:
-                transformed[key] = output_processors[key].process(value)
-            else:
-                transformed[key] = value
-        writer.add_item(write_index, transformed)
-        write_index += 1
-        complete += 1
+    builder = SampleBuilder.load(f"{output_dir}/schema.pkl")
 
-        if complete >= BATCH_SIZE:
-            progress.put(complete)
-            complete = 0
+    complete = 0
+    write_index = 0
+    for i in range(start_idx, end_idx):
+        transformed: Dict[str, Any] = builder.transform(dataset[i])
+        writer.add_item(write_index, transformed)
+        write_index += 1
+        complete += 1
+
+        if complete >= BATCH_SIZE:
+            progress.put(complete)
+            complete = 0
 
     if complete > 0:
         progress.put(complete)
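The refactor above collapses the inline schema loading and per-key dispatch into SampleBuilder.load / builder.transform, leaving the worker responsible only for writing and for batched progress reporting. That reporting pattern — count locally, flush the count to a shared queue once it reaches BATCH_SIZE, then flush any remainder — keeps inter-process queue traffic low. A minimal self-contained sketch of the pattern (the worker body and BATCH_SIZE value are illustrative stand-ins, not PyHealth's actual code):

    from multiprocessing import Process, Queue

    BATCH_SIZE = 100  # illustrative flush threshold

    def worker(progress: Queue, start_idx: int, end_idx: int) -> None:
        complete = 0
        for i in range(start_idx, end_idx):
            _ = i * i  # stand-in for builder.transform(dataset[i])
            complete += 1
            if complete >= BATCH_SIZE:  # flush a full batch
                progress.put(complete)
                complete = 0
        if complete > 0:  # flush the final partial batch
            progress.put(complete)

    if __name__ == "__main__":
        q: Queue = Queue()
        p = Process(target=worker, args=(q, 0, 250))
        p.start()
        p.join()
        total = 0
        while not q.empty():
            total += q.get()
        print(total)  # 250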
57 changes: 54 additions & 3 deletions pyhealth/datasets/sample_dataset.py
@@ -11,7 +11,7 @@
 from litdata.utilities.train_test_split import deepcopy_dataset
 import copy
 
-from ..processors import get_processor
+from ..processors import get_processor, IgnoreProcessor
 from ..processors.base_processor import FeatureProcessor
 
 
@@ -191,8 +191,14 @@ def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]:
         transformed: Dict[str, Any] = {}
         for key, value in pickle.loads(sample["sample"]).items():
             if key in self._input_processors:
+                # Skip ignored features
+                if isinstance(self._input_processors[key], IgnoreProcessor):
+                    continue
                 transformed[key] = self._input_processors[key].process(value)
             elif key in self._output_processors:
+                # Skip ignored features
+                if isinstance(self._output_processors[key], IgnoreProcessor):
+                    continue
                 transformed[key] = self._output_processors[key].process(value)
             else:
                 transformed[key] = value
@@ -221,6 +227,30 @@ def save(self, path: str) -> None:
         with open(path, "wb") as f:
             pickle.dump(metadata, f)
 
+    @staticmethod
+    def load(path: str) -> "SampleBuilder":
+        """Load a SampleBuilder from a pickled metadata file.
+
+        Args:
+            path: Location of the pickled metadata file (commonly named `schema.pkl`).
+
+        Returns:
+            A SampleBuilder instance with loaded metadata.
+        """
+        with open(path, "rb") as f:
+            metadata = pickle.load(f)
+
+        builder = SampleBuilder(
+            input_schema=metadata["input_schema"],
+            output_schema=metadata["output_schema"],
+        )
+        builder._input_processors = metadata["input_processors"]
+        builder._output_processors = metadata["output_processors"]
+        builder._patient_to_index = metadata["patient_to_index"]
+        builder._record_to_index = metadata["record_to_index"]
+        builder._fitted = True
+        return builder
+
 
 class SampleDataset(litdata.StreamingDataset):
     """A streaming dataset that loads sample metadata and processors from disk.
@@ -276,10 +306,29 @@ def __init__(
         self.output_schema = metadata["output_schema"]
         self.input_processors = metadata["input_processors"]
         self.output_processors = metadata["output_processors"]
+        self._remove_ignored_processors()
 
         self.patient_to_index = metadata["patient_to_index"]
         self.record_to_index = metadata["record_to_index"]
 
+    def _remove_ignored_processors(self):
+        """Remove any processors that are IgnoreProcessor instances."""
+        for key in [
+            key
+            for key, proc in self.input_processors.items()
+            if isinstance(proc, IgnoreProcessor)
+        ]:
+            del self.input_processors[key]
+            del self.input_schema[key]
+
+        for key in [
+            key
+            for key, proc in self.output_processors.items()
+            if isinstance(proc, IgnoreProcessor)
+        ]:
+            del self.output_processors[key]
+            del self.output_schema[key]
+
     def __str__(self) -> str:
         """Returns a string representation of the dataset.
 
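_remove_ignored_processors materializes the matching keys into a list before deleting them because removing entries from a dict while iterating over it raises a RuntimeError. A standalone illustration of the collect-then-delete pattern:

    procs = {"a": "keep", "b": "ignore", "c": "ignore"}

    # for k, v in procs.items(): del procs[k]  ->  RuntimeError
    to_drop = [k for k, v in procs.items() if v == "ignore"]
    for k in to_drop:
        del procs[k]

    print(procs)  # {'a': 'keep'}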
@@ -356,12 +405,12 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset":
         new_dataset.reset()
 
         return new_dataset
-
+
     def close(self) -> None:
         """Cleans up any temporary directories used by the dataset."""
         if self.input_dir.path is not None and Path(self.input_dir.path).exists():
             shutil.rmtree(self.input_dir.path)
-
+
     # --------------------------------------------------------------
     # Context manager support
     # --------------------------------------------------------------
@@ -426,6 +475,7 @@ def __init__(
         self.output_schema = builder.output_schema
         self.input_processors = builder.input_processors
         self.output_processors = builder.output_processors
+        self._remove_ignored_processors()
 
         self.patient_to_index = builder.patient_to_index
         self.record_to_index = builder.record_to_index
@@ -482,6 +532,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset:
     def close(self) -> None:
         pass  # No temporary directories to clean up for in-memory dataset
 
+
 def create_sample_dataset(
     samples: List[Dict[str, Any]],
     input_schema: Dict[str, Any],
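close() pairs with the context-manager support referenced in the banner above. Assuming __enter__ returns the dataset and __exit__ delegates to close() (neither method is shown in this hunk), the temporary cache directory is cleaned up automatically:

    # Hypothetical usage; assumes __enter__/__exit__ delegate to close().
    with dataset.set_task(task) as sample_ds:
        first = sample_ds[0]
    # the dataset's temporary directory has been removed at this point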
1 change: 1 addition & 0 deletions pyhealth/processors/__init__.py
@@ -45,6 +45,7 @@ def get_processor(name: str):
 from .text_processor import TextProcessor
 from .timeseries_processor import TimeseriesProcessor
 from .audio_processor import AudioProcessor
+from .ignore_processor import IgnoreProcessor
 
 # Expose public API
 __all__ = [
26 changes: 26 additions & 0 deletions pyhealth/processors/ignore_processor.py
@@ -0,0 +1,26 @@
+from typing import Any
+
+from . import register_processor
+from .base_processor import FeatureProcessor
+
+
+@register_processor("ignore")
+class IgnoreProcessor(FeatureProcessor):
+    """A special feature processor that marks a feature to be ignored during processing.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def process(self, value: Any) -> Any:
+        """This method is intentionally not implemented.
+
+        Args:
+            value: Any raw field value.
+
+        Raises:
+            NotImplementedError: Always raised; consumers such as
+                SampleBuilder.transform must skip this processor rather than call it.
+        """
+        raise NotImplementedError("IgnoreProcessor.process() should never be called.")
+
+    def __repr__(self) -> str:
+        return "IgnoreProcessor()"
92 changes: 92 additions & 0 deletions tests/core/test_ignore_processor.py
@@ -0,0 +1,92 @@
+import shutil
+import tempfile
+import unittest
+
+import dask.dataframe as dd
+import pandas as pd
+
+from pyhealth.datasets.base_dataset import BaseDataset
+from pyhealth.processors import RawProcessor
+from pyhealth.processors.ignore_processor import IgnoreProcessor
+from pyhealth.tasks.base_task import BaseTask
+
+
+class MockTask(BaseTask):
+    task_name = "test_task"
+    input_schema = {
+        "keep_field": "raw",
+        "ignore_field": "raw",
+    }
+    output_schema = {"label": "binary"}
+
+    def __call__(self, patient):
+        return [{
+            "keep_field": "keep_val",
+            "ignore_field": "ignore_val",
+            "label": 0 if patient.patient_id == "1" else 1,
+            "patient_id": patient.patient_id,
+        }]
+
+
+class MockDataset(BaseDataset):
+    def __init__(self, root, **kwargs):
+        super().__init__(root=root, tables=[], **kwargs)
+
+    def load_data(self):
+        return dd.from_pandas(
+            pd.DataFrame({
+                "patient_id": ["1", "2"],
+                "event_type": ["visit", "visit"],
+                "timestamp": [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")],
+            }),
+            npartitions=1,
+        )
+
+
+class TestIgnoreProcessor(unittest.TestCase):
+    def setUp(self):
+        self.tmp_dir = tempfile.mkdtemp()
+        self.root = self.tmp_dir
+        self.dataset = MockDataset(root=self.root)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+
+    def test_ignore_processor_with_set_task(self):
+        task = MockTask()
+
+        # 1. Plain set_task: the field is processed normally and present.
+        ds1 = self.dataset.set_task(task)
+        self.assertIn("ignore_field", ds1.input_schema)
+
+        # SampleDataset is a StreamingDataset, so the first sample can be
+        # accessed by index (or by iteration).
+        sample1 = ds1[0]
+        self.assertIn("ignore_field", sample1)
+        self.assertEqual(sample1["ignore_field"], "ignore_val")
+
+        # 2. set_task with an IgnoreProcessor. Processors must be provided
+        # for ALL fields so SampleBuilder does not fall back to its
+        # re-population logic.
+        ds2 = self.dataset.set_task(
+            task,
+            input_processors={
+                "keep_field": RawProcessor(),
+                "ignore_field": IgnoreProcessor(),
+            },
+        )
+
+        # "ignore_field" should be removed from both the input_schema and
+        # the input_processors of the resulting dataset.
+        self.assertNotIn("ignore_field", ds2.input_schema)
+        self.assertNotIn("ignore_field", ds2.input_processors)
+
+        # ... and it should not appear in the sample data either.
+        sample2 = ds2[0]
+        self.assertNotIn("ignore_field", sample2)
+
+        # "keep_field" is untouched.
+        self.assertIn("keep_field", sample2)
+        self.assertEqual(sample2["keep_field"], "keep_val")
+
+
+if __name__ == "__main__":
+    unittest.main()
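The test can be run on its own with unittest, assuming the repository layout shown in this diff:

    python -m unittest tests/core/test_ignore_processor.py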