Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pyhealth/processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,20 @@ def process(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
List of processed sample dictionaries.
"""
pass

class VocabMixin(ABC):
    """Mixin interface for feature processors that maintain a code vocabulary.

    Concrete processors implement :meth:`remove` and :meth:`retain` so that a
    fitted vocabulary can be pruned after the fact (e.g. dropping rare or
    unwanted codes) through one common interface.
    """

    @abstractmethod
    def remove(self, vocabularies: set[str]):
        """Drop the given vocabulary entries from this processor."""
        ...

    @abstractmethod
    def retain(self, vocabularies: set[str]):
        """Keep only the given vocabulary entries in this processor."""
        ...
20 changes: 17 additions & 3 deletions pyhealth/processors/deep_nested_sequence_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Iterable

import torch

Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(self):
self._max_middle_len = 1 # Maximum length of middle sequences (e.g. visits)
self._max_inner_len = 1 # Maximum length of inner sequences (e.g. codes per visit)

def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
"""Build vocabulary and determine maximum sequence lengths.

Args:
Expand Down Expand Up @@ -86,6 +86,20 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
self._max_middle_len = max(1, max_middle_len)
self._max_inner_len = max(1, max_inner_len)

def remove(self, vocabularies: set[str]):
    """Remove the given codes from the vocabulary and re-index the rest.

    Args:
        vocabularies: Codes to drop. The special "<pad>" and "<unk>"
            tokens are always preserved at their reserved indices.
    """
    # Sort so the rebuilt index assignment is reproducible across runs:
    # bare set iteration order varies with PYTHONHASHSEED, which would
    # silently invalidate any persisted index <-> code mapping.
    kept = sorted(set(self.code_vocab.keys()) - vocabularies - {"<pad>", "<unk>"})
    # NOTE(review): "<unk>" at index -1 mirrors this processor's existing
    # vocab layout — confirm downstream embedding lookups expect that.
    self.code_vocab = {"<pad>": 0, "<unk>": -1}
    for i, code in enumerate(kept):
        self.code_vocab[code] = i + 1

def retain(self, vocabularies: set[str]):
    """Retain only the given codes (plus special tokens) and re-index.

    Args:
        vocabularies: Codes to keep. "<pad>" and "<unk>" are always kept
            at their reserved indices regardless of this set.
    """
    # Exclude the specials from the intersection: if the caller's set
    # contains "<pad>"/"<unk>", they would otherwise be assigned i + 1
    # below, clobbering the reserved 0 / -1 mappings.
    # Sorted so re-indexing is reproducible across runs (set iteration
    # order varies with PYTHONHASHSEED).
    kept = sorted((set(self.code_vocab.keys()) & vocabularies) - {"<pad>", "<unk>"})
    self.code_vocab = {"<pad>": 0, "<unk>": -1}
    for i, code in enumerate(kept):
        self.code_vocab[code] = i + 1

def process(self, value: List[List[List[Any]]]) -> torch.Tensor:
"""Process deep nested sequence into padded 3D tensor.

Expand Down Expand Up @@ -209,7 +223,7 @@ def __init__(self, forward_fill: bool = True):
self._max_inner_len = 1 # Maximum length of inner sequences (values per visit)
self.forward_fill = forward_fill

def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
"""Determine maximum sequence lengths.

Args:
Expand Down
12 changes: 12 additions & 0 deletions pyhealth/processors/nested_sequence_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
# (-1 because <unk> is already in vocab)
self.code_vocab["<unk>"] = len(self.code_vocab) - 1

def remove(self, vocabularies: set[str]):
    """Remove the given codes from the vocabulary and re-index the rest.

    Args:
        vocabularies: Codes to drop. "<pad>" (index 0) and "<unk>" (last
            index) are always preserved.
    """
    # Sort so the rebuilt indices are reproducible across runs — bare set
    # iteration order varies with PYTHONHASHSEED and would invalidate any
    # persisted index <-> code mapping.
    kept = sorted(set(self.code_vocab.keys()) - vocabularies - {"<pad>", "<unk>"})
    ordered = ["<pad>"] + kept + ["<unk>"]
    self.code_vocab = {v: i for i, v in enumerate(ordered)}

def retain(self, vocabularies: set[str]):
    """Retain only the given codes (plus special tokens) and re-index.

    Args:
        vocabularies: Codes to keep. "<pad>" and "<unk>" are always kept
            at index 0 / last regardless of this set.
    """
    # Exclude specials from the intersection: if the caller's set contains
    # "<pad>"/"<unk>", the later duplicate in the list would win in the
    # dict comprehension and clobber "<pad>"'s index-0 slot.
    # Sorted so re-indexing is reproducible across runs.
    kept = sorted((set(self.code_vocab.keys()) & vocabularies) - {"<pad>", "<unk>"})
    ordered = ["<pad>"] + kept + ["<unk>"]
    self.code_vocab = {v: i for i, v in enumerate(ordered)}

def process(self, value: List[List[Any]]) -> torch.Tensor:
"""Process nested sequence into padded 2D tensor.

Expand Down
12 changes: 12 additions & 0 deletions pyhealth/processors/sequence_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def process(self, value: Any) -> torch.Tensor:
indices.append(self.code_vocab["<unk>"])

return torch.tensor(indices, dtype=torch.long)

def remove(self, vocabularies: set[str]):
    """Remove the given codes from the vocabulary and re-index the rest.

    Args:
        vocabularies: Codes to drop. "<pad>" (index 0) and "<unk>" (last
            index) are always preserved.
    """
    # Sort so the rebuilt indices are reproducible across runs — bare set
    # iteration order varies with PYTHONHASHSEED, which would silently
    # invalidate any persisted index <-> code mapping.
    kept = sorted(set(self.code_vocab.keys()) - vocabularies - {"<pad>", "<unk>"})
    ordered = ["<pad>"] + kept + ["<unk>"]
    self.code_vocab = {v: i for i, v in enumerate(ordered)}

def retain(self, vocabularies: set[str]):
    """Retain only the given codes (plus special tokens) and re-index.

    Args:
        vocabularies: Codes to keep. "<pad>" and "<unk>" are always kept
            at index 0 / last regardless of this set.
    """
    # Exclude specials from the intersection: if the caller's set contains
    # "<pad>"/"<unk>", the later duplicate in the list would win in the
    # dict comprehension and clobber "<pad>"'s index-0 slot.
    # Sorted so re-indexing is reproducible across runs.
    kept = sorted((set(self.code_vocab.keys()) & vocabularies) - {"<pad>", "<unk>"})
    ordered = ["<pad>"] + kept + ["<unk>"]
    self.code_vocab = {v: i for i, v in enumerate(ordered)}

def size(self) -> int:
    """Return the vocabulary size, including the "<pad>" and "<unk>" entries."""
    return len(self.code_vocab)
Expand Down
16 changes: 14 additions & 2 deletions pyhealth/processors/stagenet_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch

from . import register_processor
from .base_processor import FeatureProcessor
from .base_processor import FeatureProcessor, VocabMixin


@register_processor("stagenet")
class StageNetProcessor(FeatureProcessor):
class StageNetProcessor(FeatureProcessor, VocabMixin):
"""
Feature processor for StageNet CODE inputs with coupled value/time data.

Expand Down Expand Up @@ -122,6 +122,18 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
# Since <unk> is already in the vocab dict, we use _next_index
self.code_vocab["<unk>"] = self._next_index

def remove(self, vocabularies: set[str]):
    """Remove the given codes from the vocabulary and re-index the rest.

    Args:
        vocabularies: Codes to drop. "<pad>" (index 0) and "<unk>" (last
            index) are always preserved.
    """
    # Sort so the rebuilt indices are reproducible across runs — bare set
    # iteration order varies with PYTHONHASHSEED, which would silently
    # invalidate any persisted index <-> code mapping.
    kept = sorted(set(self.code_vocab.keys()) - vocabularies - {"<pad>", "<unk>"})
    ordered = ["<pad>"] + kept + ["<unk>"]
    self.code_vocab = {v: i for i, v in enumerate(ordered)}

def retain(self, vocabularies: set[str]):
    """Retain only the given codes (plus special tokens) and re-index.

    Args:
        vocabularies: Codes to keep. "<pad>" and "<unk>" are always kept
            at index 0 / last regardless of this set.
    """
    # Exclude specials from the intersection: if the caller's set contains
    # "<pad>"/"<unk>", the later duplicate in the list would win in the
    # dict comprehension and clobber "<pad>"'s index-0 slot.
    # Sorted so re-indexing is reproducible across runs.
    kept = sorted((set(self.code_vocab.keys()) & vocabularies) - {"<pad>", "<unk>"})
    ordered = ["<pad>"] + kept + ["<unk>"]
    self.code_vocab = {v: i for i, v in enumerate(ordered)}

def process(
self, value: Tuple[Optional[List], List]
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
Expand Down
182 changes: 182 additions & 0 deletions tests/core/test_vocab_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import unittest
from typing import List, Dict, Any
from pyhealth.processors import (
SequenceProcessor,
StageNetProcessor,
NestedSequenceProcessor,
DeepNestedSequenceProcessor,
)

class TestVocabProcessors(unittest.TestCase):
    """
    Test remove and retain methods for processors with vocabulary support.
    Covers: SequenceProcessor, StageNetProcessor, NestedSequenceProcessor,
    and DeepNestedSequenceProcessor.
    """

    def test_sequence_processor_remove(self):
        processor = SequenceProcessor()
        samples = [
            {"codes": ["A", "B", "C"]},
            {"codes": ["D", "E"]},
        ]
        processor.fit(samples, "codes")
        original_vocab = set(processor.code_vocab.keys())
        self.assertTrue({"A", "B", "C", "D", "E"}.issubset(original_vocab))

        # Remove "A" and "B"; special tokens must survive the rebuild.
        processor.remove({"A", "B"})
        new_vocab = set(processor.code_vocab.keys())
        self.assertNotIn("A", new_vocab)
        self.assertNotIn("B", new_vocab)
        self.assertIn("C", new_vocab)
        self.assertIn("D", new_vocab)
        self.assertIn("E", new_vocab)
        self.assertIn("<unk>", new_vocab)
        self.assertIn("<pad>", new_vocab)

        # Verify processing still works (removed codes map to <unk>)
        res = processor.process(["A", "C"])
        unk_idx = processor.code_vocab["<unk>"]
        c_idx = processor.code_vocab["C"]
        self.assertEqual(res[0].item(), unk_idx)
        self.assertEqual(res[1].item(), c_idx)

    def test_sequence_processor_retain(self):
        processor = SequenceProcessor()
        samples = [
            {"codes": ["A", "B", "C"]},
            {"codes": ["D", "E"]},
        ]
        processor.fit(samples, "codes")

        # Retain only "A" and "B"; everything else except the special
        # tokens should be dropped.
        processor.retain({"A", "B"})
        new_vocab = set(processor.code_vocab.keys())
        self.assertIn("A", new_vocab)
        self.assertIn("B", new_vocab)
        self.assertNotIn("C", new_vocab)
        self.assertNotIn("D", new_vocab)
        self.assertNotIn("E", new_vocab)
        self.assertIn("<unk>", new_vocab)
        self.assertIn("<pad>", new_vocab)

    def test_stagenet_processor_remove(self):
        processor = StageNetProcessor()
        # Flat (time, codes) tuples — StageNet couples values with times.
        samples = [
            {"data": ([0.0, 1.0, 2.0], ["A", "B", "C"])},
            {"data": ([0.0, 1.0], ["D", "E"])},
        ]
        processor.fit(samples, "data")

        processor.remove({"A", "B"})
        new_vocab = set(processor.code_vocab.keys())
        self.assertNotIn("A", new_vocab)
        self.assertNotIn("B", new_vocab)
        self.assertIn("C", new_vocab)
        self.assertIn("D", new_vocab)
        self.assertIn("E", new_vocab)

        # Test processing: removed code "A" should map to <unk>.
        time, res = processor.process(([0.0, 1.0], ["A", "C"]))
        unk_idx = processor.code_vocab["<unk>"]
        c_idx = processor.code_vocab["C"]
        self.assertEqual(res[0].item(), unk_idx)
        self.assertEqual(res[1].item(), c_idx)

    def test_stagenet_processor_retain(self):
        processor = StageNetProcessor()
        # Nested codes (one inner list per visit).
        samples = [
            {"data": ([0.0, 1.0], [["A", "B"], ["C"]])},
            {"data": ([0.0], [["D", "E"]])},
        ]
        processor.fit(samples, "data")

        processor.retain({"A", "B"})
        new_vocab = set(processor.code_vocab.keys())
        self.assertIn("A", new_vocab)
        self.assertIn("B", new_vocab)
        self.assertNotIn("C", new_vocab)
        self.assertNotIn("D", new_vocab)

    def test_nested_sequence_processor_remove(self):
        processor = NestedSequenceProcessor()
        samples = [
            {"codes": [["A", "B"], ["C", "D"]]},
            {"codes": [["E"]]},
        ]
        processor.fit(samples, "codes")

        processor.remove({"A", "B"})
        new_vocab = set(processor.code_vocab.keys())
        self.assertNotIn("A", new_vocab)
        self.assertNotIn("B", new_vocab)
        self.assertIn("C", new_vocab)
        self.assertIn("D", new_vocab)
        self.assertIn("E", new_vocab)

        res = processor.process([["A", "C"]])
        unk_idx = processor.code_vocab["<unk>"]
        c_idx = processor.code_vocab["C"]
        # res shape (1, max_inner_len)
        # First code in first visit should be <unk>, second C.
        # Note: the processor pads each visit to max_inner_len.
        visit = res[0]
        self.assertEqual(visit[0].item(), unk_idx)
        self.assertEqual(visit[1].item(), c_idx)

    def test_nested_sequence_processor_retain(self):
        processor = NestedSequenceProcessor()
        samples = [
            {"codes": [["A", "B"], ["C", "D"]]},
            {"codes": [["E"]]},
        ]
        processor.fit(samples, "codes")

        processor.retain({"E"})
        new_vocab = set(processor.code_vocab.keys())
        self.assertIn("E", new_vocab)
        self.assertNotIn("A", new_vocab)
        self.assertNotIn("B", new_vocab)
        self.assertNotIn("C", new_vocab)
        self.assertNotIn("D", new_vocab)

    def test_deep_nested_sequence_processor_remove(self):
        processor = DeepNestedSequenceProcessor()
        samples = [
            {"codes": [[["A", "B"], ["C"]], [["D"]]]},
        ]
        processor.fit(samples, "codes")

        processor.remove({"A"})
        new_vocab = set(processor.code_vocab.keys())
        self.assertNotIn("A", new_vocab)
        self.assertIn("B", new_vocab)
        self.assertIn("C", new_vocab)
        self.assertIn("D", new_vocab)

        # Test process:
        # Input [[["A"]]] -> [[[<unk>]]] (padded to max lengths)
        res = processor.process([[["A"]]])
        unk_idx = processor.code_vocab["<unk>"]
        # res shape (1, max_visits, max_codes):
        # first group, first visit, first code
        self.assertEqual(res[0, 0, 0].item(), unk_idx)

    def test_deep_nested_sequence_processor_retain(self):
        processor = DeepNestedSequenceProcessor()
        samples = [
            {"codes": [[["A", "B"], ["C"]], [["D"]]]},
        ]
        processor.fit(samples, "codes")

        processor.retain({"A"})
        new_vocab = set(processor.code_vocab.keys())
        self.assertIn("A", new_vocab)
        self.assertNotIn("B", new_vocab)
        self.assertNotIn("C", new_vocab)
        self.assertNotIn("D", new_vocab)

if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest collection).
    unittest.main()