diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py
index 823fcafec..41697cfd4 100644
--- a/pyhealth/processors/base_processor.py
+++ b/pyhealth/processors/base_processor.py
@@ -92,3 +92,20 @@ def process(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
             List of processed sample dictionaries.
         """
         pass
+
+class VocabMixin(ABC):
+    """
+    Base class for feature processors that build a vocabulary.
+
+    Provides a common interface for modifying a processor's vocabulary.
+    """
+
+    @abstractmethod
+    def remove(self, vocabularies: set[str]):
+        """Remove specified vocabularies from the processor."""
+        pass
+
+    @abstractmethod
+    def retain(self, vocabularies: set[str]):
+        """Retain only the specified vocabularies in the processor."""
+        pass
\ No newline at end of file
diff --git a/pyhealth/processors/deep_nested_sequence_processor.py b/pyhealth/processors/deep_nested_sequence_processor.py
index 53fb717f3..333d86e9d 100644
--- a/pyhealth/processors/deep_nested_sequence_processor.py
+++ b/pyhealth/processors/deep_nested_sequence_processor.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Iterable
 
 import torch
 
@@ -51,7 +51,7 @@ def __init__(self):
         self._max_middle_len = 1  # Maximum length of middle sequences (e.g. visits)
         self._max_inner_len = 1  # Maximum length of inner sequences (e.g. codes per visit)
 
-    def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
+    def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
         """Build vocabulary and determine maximum sequence lengths.
 
         Args:
@@ -86,6 +86,20 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
         self._max_middle_len = max(1, max_middle_len)
         self._max_inner_len = max(1, max_inner_len)
 
+    def remove(self, vocabularies: set[str]):
+        """Remove specified vocabularies from the processor."""
+        vocab = list(set(self.code_vocab.keys()) - vocabularies - {"<pad>", "<unk>"})
+        self.code_vocab = {"<pad>": 0, "<unk>": -1}
+        for i, v in enumerate(vocab):
+            self.code_vocab[v] = i + 1
+
+    def retain(self, vocabularies: set[str]):
+        """Retain only the specified vocabularies in the processor."""
+        vocab = list(set(self.code_vocab.keys()) & vocabularies)
+        self.code_vocab = {"<pad>": 0, "<unk>": -1}
+        for i, v in enumerate(vocab):
+            self.code_vocab[v] = i + 1
+
     def process(self, value: List[List[List[Any]]]) -> torch.Tensor:
         """Process deep nested sequence into padded 3D tensor.
 
@@ -209,7 +223,7 @@ def __init__(self, forward_fill: bool = True):
         self._max_inner_len = 1  # Maximum length of inner sequences (values per visit)
         self.forward_fill = forward_fill
 
-    def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
+    def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
         """Determine maximum sequence lengths.
 
         Args:
diff --git a/pyhealth/processors/nested_sequence_processor.py b/pyhealth/processors/nested_sequence_processor.py
index 80a2567bc..a7593ec74 100644
--- a/pyhealth/processors/nested_sequence_processor.py
+++ b/pyhealth/processors/nested_sequence_processor.py
@@ -86,6 +86,18 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
         # (-1 because <unk> is already in vocab)
         self.code_vocab["<unk>"] = len(self.code_vocab) - 1
 
+    def remove(self, vocabularies: set[str]):
+        """Remove specified vocabularies from the processor."""
+        vocab = list(set(self.code_vocab.keys()) - vocabularies - {"<pad>", "<unk>"})
+        vocab = ["<pad>"] + vocab + ["<unk>"]
+        self.code_vocab = {v: i for i, v in enumerate(vocab)}
+
+    def retain(self, vocabularies: set[str]):
+        """Retain only the specified vocabularies in the processor."""
+        vocab = list(set(self.code_vocab.keys()) & vocabularies)
+        vocab = ["<pad>"] + vocab + ["<unk>"]
+        self.code_vocab = {v: i for i, v in enumerate(vocab)}
+
     def process(self, value: List[List[Any]]) -> torch.Tensor:
         """Process nested sequence into padded 2D tensor.
 
diff --git a/pyhealth/processors/sequence_processor.py b/pyhealth/processors/sequence_processor.py
index 32bb9e2cf..e45af1c38 100644
--- a/pyhealth/processors/sequence_processor.py
+++ b/pyhealth/processors/sequence_processor.py
@@ -48,6 +48,18 @@ def process(self, value: Any) -> torch.Tensor:
                 indices.append(self.code_vocab["<unk>"])
 
         return torch.tensor(indices, dtype=torch.long)
+
+    def remove(self, vocabularies: set[str]):
+        """Remove specified vocabularies from the processor."""
+        vocab = list(set(self.code_vocab.keys()) - vocabularies - {"<pad>", "<unk>"})
+        vocab = ["<pad>"] + vocab + ["<unk>"]
+        self.code_vocab = {v: i for i, v in enumerate(vocab)}
+
+    def retain(self, vocabularies: set[str]):
+        """Retain only the specified vocabularies in the processor."""
+        vocab = list(set(self.code_vocab.keys()) & vocabularies)
+        vocab = ["<pad>"] + vocab + ["<unk>"]
+        self.code_vocab = {v: i for i, v in enumerate(vocab)}
 
     def size(self):
         return len(self.code_vocab)
diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py
index fc3bf3d01..a1272ce24 100644
--- a/pyhealth/processors/stagenet_processor.py
+++ b/pyhealth/processors/stagenet_processor.py
@@ -3,11 +3,11 @@
 import torch
 
 from . import register_processor
-from .base_processor import FeatureProcessor
+from .base_processor import FeatureProcessor, VocabMixin
 
 
 @register_processor("stagenet")
-class StageNetProcessor(FeatureProcessor):
+class StageNetProcessor(FeatureProcessor, VocabMixin):
     """
     Feature processor for StageNet CODE inputs with coupled value/time data.
 
@@ -122,6 +122,18 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
         # Since <unk> is already in the vocab dict, we use _next_index
         self.code_vocab["<unk>"] = self._next_index
 
+    def remove(self, vocabularies: set[str]):
+        """Remove specified vocabularies from the processor."""
+        vocab = list(set(self.code_vocab.keys()) - vocabularies - {"<pad>", "<unk>"})
+        vocab = ["<pad>"] + vocab + ["<unk>"]
+        self.code_vocab = {v: i for i, v in enumerate(vocab)}
+
+    def retain(self, vocabularies: set[str]):
+        """Retain only the specified vocabularies in the processor."""
+        vocab = list(set(self.code_vocab.keys()) & vocabularies)
+        vocab = ["<pad>"] + vocab + ["<unk>"]
+        self.code_vocab = {v: i for i, v in enumerate(vocab)}
+
     def process(
         self, value: Tuple[Optional[List], List]
     ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
diff --git a/tests/core/test_vocab_processors.py b/tests/core/test_vocab_processors.py
new file mode 100644
index 000000000..f53629916
--- /dev/null
+++ b/tests/core/test_vocab_processors.py
@@ -0,0 +1,182 @@
+import unittest
+from typing import List, Dict, Any
+from pyhealth.processors import (
+    SequenceProcessor,
+    StageNetProcessor,
+    NestedSequenceProcessor,
+    DeepNestedSequenceProcessor,
+)
+
+class TestVocabProcessors(unittest.TestCase):
+    """
+    Test remove and retain methods for processors with vocabulary support.
+    Covers: SequenceProcessor, StageNetProcessor, NestedSequenceProcessor, DeepNestedSequenceProcessor.
+    """
+
+    def test_sequence_processor_remove(self):
+        processor = SequenceProcessor()
+        samples = [
+            {"codes": ["A", "B", "C"]},
+            {"codes": ["D", "E"]},
+        ]
+        processor.fit(samples, "codes")
+        original_vocab = set(processor.code_vocab.keys())
+        self.assertTrue({"A", "B", "C", "D", "E"}.issubset(original_vocab))
+
+        # Remove "A" and "B"
+        processor.remove({"A", "B"})
+        new_vocab = set(processor.code_vocab.keys())
+        self.assertNotIn("A", new_vocab)
+        self.assertNotIn("B", new_vocab)
+        self.assertIn("C", new_vocab)
+        self.assertIn("D", new_vocab)
+        self.assertIn("E", new_vocab)
+        self.assertIn("<pad>", new_vocab)
+        self.assertIn("<unk>", new_vocab)
+
+        # Verify processing still works (A and B become <unk>)
+        res = processor.process(["A", "C"])
+        unk_idx = processor.code_vocab["<unk>"]
+        c_idx = processor.code_vocab["C"]
+        self.assertEqual(res[0].item(), unk_idx)
+        self.assertEqual(res[1].item(), c_idx)
+
+    def test_sequence_processor_retain(self):
+        processor = SequenceProcessor()
+        samples = [
+            {"codes": ["A", "B", "C"]},
+            {"codes": ["D", "E"]},
+        ]
+        processor.fit(samples, "codes")
+
+        # Retain "A" and "B"
+        processor.retain({"A", "B"})
+        new_vocab = set(processor.code_vocab.keys())
+        self.assertIn("A", new_vocab)
+        self.assertIn("B", new_vocab)
+        self.assertNotIn("C", new_vocab)
+        self.assertNotIn("D", new_vocab)
+        self.assertNotIn("E", new_vocab)
+        self.assertIn("<pad>", new_vocab)
+        self.assertIn("<unk>", new_vocab)
+
+    def test_stagenet_processor_remove(self):
+        processor = StageNetProcessor()
+        # Flat codes
+        samples = [
+            {"data": ([0.0, 1.0, 2.0], ["A", "B", "C"])},
+            {"data": ([0.0, 1.0], ["D", "E"])},
+        ]
+        processor.fit(samples, "data")
+
+        processor.remove({"A", "B"})
+        new_vocab = set(processor.code_vocab.keys())
+        self.assertNotIn("A", new_vocab)
+        self.assertNotIn("B", new_vocab)
+        self.assertIn("C", new_vocab)
+        self.assertIn("D", new_vocab)
+        self.assertIn("E", new_vocab)
+
+        # Test processing
+        time, res = processor.process(([0.0, 1.0], ["A", "C"]))
+        unk_idx = processor.code_vocab["<unk>"]
+        c_idx = processor.code_vocab["C"]
+        self.assertEqual(res[0].item(), unk_idx)
+        self.assertEqual(res[1].item(), c_idx)
+
+    def test_stagenet_processor_retain(self):
+        processor = StageNetProcessor()
+        # Nested codes
+        samples = [
+            {"data": ([0.0, 1.0], [["A", "B"], ["C"]])},
+            {"data": ([0.0], [["D", "E"]])},
+        ]
+        processor.fit(samples, "data")
+
+        processor.retain({"A", "B"})
+        new_vocab = set(processor.code_vocab.keys())
+        self.assertIn("A", new_vocab)
+        self.assertIn("B", new_vocab)
+        self.assertNotIn("C", new_vocab)
+        self.assertNotIn("D", new_vocab)
+
+    def test_nested_sequence_processor_remove(self):
+        processor = NestedSequenceProcessor()
+        samples = [
+            {"codes": [["A", "B"], ["C", "D"]]},
+            {"codes": [["E"]]},
+        ]
+        processor.fit(samples, "codes")
+
+        processor.remove({"A", "B"})
+        new_vocab = set(processor.code_vocab.keys())
+        self.assertNotIn("A", new_vocab)
+        self.assertNotIn("B", new_vocab)
+        self.assertIn("C", new_vocab)
+        self.assertIn("D", new_vocab)
+        self.assertIn("E", new_vocab)
+
+        res = processor.process([["A", "C"]])
+        unk_idx = processor.code_vocab["<unk>"]
+        c_idx = processor.code_vocab["C"]
+        # res shape (1, max_inner_len)
+        # First code in first visit should be <unk>, second C
+        # Note: processor pads to max_inner_len
+        visit = res[0]
+        self.assertEqual(visit[0].item(), unk_idx)
+        self.assertEqual(visit[1].item(), c_idx)
+
+    def test_nested_sequence_processor_retain(self):
+        processor = NestedSequenceProcessor()
+        samples = [
+            {"codes": [["A", "B"], ["C", "D"]]},
+            {"codes": [["E"]]},
+        ]
+        processor.fit(samples, "codes")
+
+        processor.retain({"E"})
+        new_vocab = set(processor.code_vocab.keys())
+        self.assertIn("E", new_vocab)
+        self.assertNotIn("A", new_vocab)
+        self.assertNotIn("B", new_vocab)
+        self.assertNotIn("C", new_vocab)
+        self.assertNotIn("D", new_vocab)
+
+    def test_deep_nested_sequence_processor_remove(self):
+        processor = DeepNestedSequenceProcessor()
+        samples = [
+            {"codes": [[["A", "B"], ["C"]], [["D"]]]},
+        ]
+        processor.fit(samples, "codes")
+
+        processor.remove({"A"})
+        new_vocab = set(processor.code_vocab.keys())
+        self.assertNotIn("A", new_vocab)
+        self.assertIn("B", new_vocab)
+        self.assertIn("C", new_vocab)
+        self.assertIn("D", new_vocab)
+
+        # Test process
+        # Input [[["A"]]] -> [[["<unk>"]]] (padded)
+        res = processor.process([[["A"]]])
+        unk_idx = processor.code_vocab["<unk>"]
+        # res shape (1, max_visits, max_codes)
+        # first group, first visit, first code
+        self.assertEqual(res[0, 0, 0].item(), unk_idx)
+
+    def test_deep_nested_sequence_processor_retain(self):
+        processor = DeepNestedSequenceProcessor()
+        samples = [
+            {"codes": [[["A", "B"], ["C"]], [["D"]]]},
+        ]
+        processor.fit(samples, "codes")
+
+        processor.retain({"A"})
+        new_vocab = set(processor.code_vocab.keys())
+        self.assertIn("A", new_vocab)
+        self.assertNotIn("B", new_vocab)
+        self.assertNotIn("C", new_vocab)
+        self.assertNotIn("D", new_vocab)
+
+if __name__ == "__main__":
+    unittest.main()
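
A minimal usage sketch of the remove/retain API added by this diff. The import path, method names, and <pad>/<unk> tokens follow tests/core/test_vocab_processors.py above; the sample data is illustrative:

    from pyhealth.processors import SequenceProcessor

    samples = [
        {"codes": ["A", "B", "C"]},
        {"codes": ["D", "E"]},
    ]

    processor = SequenceProcessor()
    processor.fit(samples, "codes")  # builds code_vocab with <pad> and <unk>

    # Prune codes from the vocabulary. Both remove() and retain() rebuild
    # code_vocab with fresh indices (<pad> first, <unk> last), so anything
    # encoded with the old vocabulary must be re-processed.
    processor.remove({"A", "B"})

    encoded = processor.process(["A", "C"])  # "A" now maps to <unk>
    assert encoded[0].item() == processor.code_vocab["<unk>"]
    assert encoded[1].item() == processor.code_vocab["C"]

    # Or keep only an allow-list of codes:
    processor.retain({"C", "D"})
    print(processor.size())  # <pad> + <unk> + the retained codes

One caveat worth noting: the surviving codes are collected through a Python set, so their relative order (and hence their indices) is not guaranteed to be stable across runs; callers should look indices up through code_vocab rather than caching them.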