Skip to content

Commit bf986e0

Browse files
authored
Allow change vocabs for processor that take vocabs (#778)
* Allow change vocab for a processor * Add test * Add `add` method to vocab processors
1 parent 470f89c commit bf986e0

File tree

6 files changed

+370
-11
lines changed

6 files changed

+370
-11
lines changed

pyhealth/processors/base_processor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,25 @@ def process(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
9292
List of processed sample dictionaries.
9393
"""
9494
pass
95+
96+
class VocabMixin(ABC):
    """Interface for feature processors whose vocabulary can be edited.

    Concrete processors must implement all three operations below; each
    receives the set of vocabulary entries to act on.
    """

    @abstractmethod
    def remove(self, vocabularies: set[str]) -> None:
        """Drop the given entries from the processor's vocabulary."""

    @abstractmethod
    def retain(self, vocabularies: set[str]) -> None:
        """Keep only the given entries in the processor's vocabulary."""

    @abstractmethod
    def add(self, vocabularies: set[str]) -> None:
        """Insert the given entries into the processor's vocabulary."""

pyhealth/processors/deep_nested_sequence_processor.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from typing import Any, Dict, List
1+
from typing import Any, Dict, List, Iterable
22

33
import torch
44

55
from . import register_processor
6-
from .base_processor import FeatureProcessor
6+
from .base_processor import FeatureProcessor, VocabMixin
77

88

99
@register_processor("deep_nested_sequence")
10-
class DeepNestedSequenceProcessor(FeatureProcessor):
10+
class DeepNestedSequenceProcessor(FeatureProcessor, VocabMixin):
1111
"""
1212
Feature processor for deeply nested categorical sequences with vocabulary.
1313
@@ -51,7 +51,7 @@ def __init__(self):
5151
self._max_middle_len = 1 # Maximum length of middle sequences (e.g. visits)
5252
self._max_inner_len = 1 # Maximum length of inner sequences (e.g. codes per visit)
5353

54-
def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
54+
def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
5555
"""Build vocabulary and determine maximum sequence lengths.
5656
5757
Args:
@@ -86,6 +86,27 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
8686
self._max_middle_len = max(1, max_middle_len)
8787
self._max_inner_len = max(1, max_inner_len)
8888

89+
def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the vocabulary and re-index it.

    The special tokens ``<pad>`` (index 0) and ``<unk>`` (index -1) are
    never removed, even if they appear in ``vocabularies``.

    Args:
        vocabularies: Tokens to drop. Tokens that are not currently in the
            vocabulary are ignored.
    """
    special = {"<pad>", "<unk>"}
    # Walk the dict itself rather than a set so the surviving tokens keep
    # their current relative order; set iteration order for strings varies
    # with the hash seed and would make the indices non-reproducible.
    kept = [
        token
        for token in self.code_vocab
        if token not in special and token not in vocabularies
    ]
    rebuilt = {"<pad>": 0, "<unk>": -1}
    for index, token in enumerate(kept, start=1):
        rebuilt[token] = index
    self.code_vocab = rebuilt
95+
96+
def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the vocabulary and re-index it.

    The special tokens ``<pad>`` (index 0) and ``<unk>`` (index -1) are
    always kept and always keep those indices.

    Args:
        vocabularies: Tokens to keep. Tokens that are not currently in the
            vocabulary are ignored (nothing is added).
    """
    special = {"<pad>", "<unk>"}
    # Exclude the special tokens explicitly: if the caller lists them in
    # ``vocabularies`` they must not be renumbered by the loop below.
    # Iterating the dict keeps the existing relative order deterministic.
    kept = [
        token
        for token in self.code_vocab
        if token in vocabularies and token not in special
    ]
    rebuilt = {"<pad>": 0, "<unk>": -1}
    for index, token in enumerate(kept, start=1):
        rebuilt[token] = index
    self.code_vocab = rebuilt
102+
103+
def add(self, vocabularies: set[str]):
    """Add the given tokens to the vocabulary.

    Existing tokens keep their current order and therefore their indices
    (assuming they were already numbered contiguously from 1). New tokens
    are appended after them in sorted order so the result is reproducible.
    ``<pad>`` stays at 0 and ``<unk>`` stays at -1.

    Args:
        vocabularies: Tokens to add. Tokens already present, and the
            special tokens themselves, are ignored.
    """
    special = {"<pad>", "<unk>"}
    # The special tokens must be filtered out of the *existing* keys too.
    # ``keys | vocabularies - special`` only strips them from
    # ``vocabularies`` because ``-`` binds tighter than ``|``, which let
    # the loop below overwrite the reserved 0 / -1 indices.
    existing = [token for token in self.code_vocab if token not in special]
    known = special.union(existing)
    added = sorted(set(vocabularies) - known, key=str)
    rebuilt = {"<pad>": 0, "<unk>": -1}
    for index, token in enumerate(existing + added, start=1):
        rebuilt[token] = index
    self.code_vocab = rebuilt
109+
89110
def process(self, value: List[List[List[Any]]]) -> torch.Tensor:
90111
"""Process deep nested sequence into padded 3D tensor.
91112
@@ -209,7 +230,7 @@ def __init__(self, forward_fill: bool = True):
209230
self._max_inner_len = 1 # Maximum length of inner sequences (values per visit)
210231
self.forward_fill = forward_fill
211232

212-
def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
233+
def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
213234
"""Determine maximum sequence lengths.
214235
215236
Args:

pyhealth/processors/nested_sequence_processor.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import torch
44

55
from . import register_processor
6-
from .base_processor import FeatureProcessor
6+
from .base_processor import FeatureProcessor, VocabMixin
77

88

99
@register_processor("nested_sequence")
10-
class NestedSequenceProcessor(FeatureProcessor):
10+
class NestedSequenceProcessor(FeatureProcessor, VocabMixin):
1111
"""
1212
Feature processor for nested categorical sequences with vocabulary.
1313
@@ -86,6 +86,24 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
8686
# (-1 because <unk> is already in vocab)
8787
self.code_vocab["<unk>"] = len(self.code_vocab) - 1
8888

89+
def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the vocabulary and re-index it.

    ``<pad>`` is always index 0 and ``<unk>`` is always the last index;
    neither is removed even if it appears in ``vocabularies``.

    Args:
        vocabularies: Tokens to drop. Tokens that are not currently in the
            vocabulary are ignored.
    """
    special = {"<pad>", "<unk>"}
    # Walk the dict itself rather than a set so the surviving tokens keep
    # their current relative order; set iteration order for strings varies
    # with the hash seed and would make the indices non-reproducible.
    kept = [
        token
        for token in self.code_vocab
        if token not in special and token not in vocabularies
    ]
    ordered = ["<pad>", *kept, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
94+
95+
def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the vocabulary and re-index it.

    ``<pad>`` (index 0) and ``<unk>`` (last index) are always kept.

    Args:
        vocabularies: Tokens to keep. Tokens that are not currently in the
            vocabulary are ignored (nothing is added).
    """
    special = {"<pad>", "<unk>"}
    # Exclude the special tokens explicitly: if the caller lists them in
    # ``vocabularies`` they would otherwise appear twice in ``ordered`` and
    # the later duplicate would overwrite ``<pad>``'s index 0.
    # Iterating the dict keeps the existing relative order deterministic.
    kept = [
        token
        for token in self.code_vocab
        if token in vocabularies and token not in special
    ]
    ordered = ["<pad>", *kept, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
100+
101+
def add(self, vocabularies: set[str]):
    """Add the given tokens to the vocabulary.

    Existing tokens keep their current order and therefore their indices
    (assuming they were already numbered contiguously from 1). New tokens
    are appended after them in sorted order so the result is reproducible.
    ``<pad>`` stays at 0 and ``<unk>`` moves to the new last index.

    Args:
        vocabularies: Tokens to add. Tokens already present, and the
            special tokens themselves, are ignored.
    """
    special = {"<pad>", "<unk>"}
    # The special tokens must be filtered out of the *existing* keys too.
    # ``keys | vocabularies - special`` only strips them from
    # ``vocabularies`` because ``-`` binds tighter than ``|``, which left
    # duplicates in the list and gave ``<pad>`` a non-zero index.
    existing = [token for token in self.code_vocab if token not in special]
    known = special.union(existing)
    added = sorted(set(vocabularies) - known, key=str)
    ordered = ["<pad>", *existing, *added, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
106+
89107
def process(self, value: List[List[Any]]) -> torch.Tensor:
90108
"""Process nested sequence into padded 2D tensor.
91109

pyhealth/processors/sequence_processor.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import torch
44

55
from . import register_processor
6-
from .base_processor import FeatureProcessor
6+
from .base_processor import FeatureProcessor, VocabMixin
77

88

99
@register_processor("sequence")
10-
class SequenceProcessor(FeatureProcessor):
10+
class SequenceProcessor(FeatureProcessor, VocabMixin):
1111
"""
1212
Feature processor for encoding categorical sequences (e.g., medical codes) into numerical indices.
1313
@@ -48,6 +48,24 @@ def process(self, value: Any) -> torch.Tensor:
4848
indices.append(self.code_vocab["<unk>"])
4949

5050
return torch.tensor(indices, dtype=torch.long)
51+
52+
def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the vocabulary and re-index it.

    ``<pad>`` is always index 0 and ``<unk>`` is always the last index;
    neither is removed even if it appears in ``vocabularies``.

    Args:
        vocabularies: Tokens to drop. Tokens that are not currently in the
            vocabulary are ignored.
    """
    special = {"<pad>", "<unk>"}
    # Walk the dict itself rather than a set so the surviving tokens keep
    # their current relative order; set iteration order for strings varies
    # with the hash seed and would make the indices non-reproducible.
    kept = [
        token
        for token in self.code_vocab
        if token not in special and token not in vocabularies
    ]
    ordered = ["<pad>", *kept, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
57+
58+
def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the vocabulary and re-index it.

    ``<pad>`` (index 0) and ``<unk>`` (last index) are always kept.

    Args:
        vocabularies: Tokens to keep. Tokens that are not currently in the
            vocabulary are ignored (nothing is added).
    """
    special = {"<pad>", "<unk>"}
    # Exclude the special tokens explicitly: if the caller lists them in
    # ``vocabularies`` they would otherwise appear twice in ``ordered`` and
    # the later duplicate would overwrite ``<pad>``'s index 0.
    # Iterating the dict keeps the existing relative order deterministic.
    kept = [
        token
        for token in self.code_vocab
        if token in vocabularies and token not in special
    ]
    ordered = ["<pad>", *kept, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
63+
64+
def add(self, vocabularies: set[str]):
    """Add the given tokens to the vocabulary.

    Existing tokens keep their current order and therefore their indices
    (assuming they were already numbered contiguously from 1). New tokens
    are appended after them in sorted order so the result is reproducible.
    ``<pad>`` stays at 0 and ``<unk>`` moves to the new last index.

    Args:
        vocabularies: Tokens to add. Tokens already present, and the
            special tokens themselves, are ignored.
    """
    special = {"<pad>", "<unk>"}
    # The special tokens must be filtered out of the *existing* keys too.
    # ``keys | vocabularies - special`` only strips them from
    # ``vocabularies`` because ``-`` binds tighter than ``|``, which left
    # duplicates in the list and gave ``<pad>`` a non-zero index.
    existing = [token for token in self.code_vocab if token not in special]
    known = special.union(existing)
    added = sorted(set(vocabularies) - known, key=str)
    ordered = ["<pad>", *existing, *added, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
5169

5270
def size(self):
5371
return len(self.code_vocab)

pyhealth/processors/stagenet_processor.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import torch
44

55
from . import register_processor
6-
from .base_processor import FeatureProcessor
6+
from .base_processor import FeatureProcessor, VocabMixin
77

88

99
@register_processor("stagenet")
10-
class StageNetProcessor(FeatureProcessor):
10+
class StageNetProcessor(FeatureProcessor, VocabMixin):
1111
"""
1212
Feature processor for StageNet CODE inputs with coupled value/time data.
1313
@@ -122,6 +122,24 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
122122
# Since <unk> is already in the vocab dict, we use _next_index
123123
self.code_vocab["<unk>"] = self._next_index
124124

125+
def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the vocabulary and re-index it.

    ``<pad>`` is always index 0 and ``<unk>`` is always the last index;
    neither is removed even if it appears in ``vocabularies``.

    Args:
        vocabularies: Tokens to drop. Tokens that are not currently in the
            vocabulary are ignored.
    """
    # NOTE(review): fit() takes the <unk> index from self._next_index; that
    # counter is not updated here -- confirm whether it needs to be.
    special = {"<pad>", "<unk>"}
    # Walk the dict itself rather than a set so the surviving tokens keep
    # their current relative order; set iteration order for strings varies
    # with the hash seed and would make the indices non-reproducible.
    kept = [
        token
        for token in self.code_vocab
        if token not in special and token not in vocabularies
    ]
    ordered = ["<pad>", *kept, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
130+
131+
def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the vocabulary and re-index it.

    ``<pad>`` (index 0) and ``<unk>`` (last index) are always kept.

    Args:
        vocabularies: Tokens to keep. Tokens that are not currently in the
            vocabulary are ignored (nothing is added).
    """
    # NOTE(review): fit() takes the <unk> index from self._next_index; that
    # counter is not updated here -- confirm whether it needs to be.
    special = {"<pad>", "<unk>"}
    # Exclude the special tokens explicitly: if the caller lists them in
    # ``vocabularies`` they would otherwise appear twice in ``ordered`` and
    # the later duplicate would overwrite ``<pad>``'s index 0.
    # Iterating the dict keeps the existing relative order deterministic.
    kept = [
        token
        for token in self.code_vocab
        if token in vocabularies and token not in special
    ]
    ordered = ["<pad>", *kept, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
136+
137+
def add(self, vocabularies: set[str]):
    """Add the given tokens to the vocabulary.

    Existing tokens keep their current order and therefore their indices
    (assuming they were already numbered contiguously from 1). New tokens
    are appended after them in sorted order so the result is reproducible.
    ``<pad>`` stays at 0 and ``<unk>`` moves to the new last index.

    Args:
        vocabularies: Tokens to add. Tokens already present, and the
            special tokens themselves, are ignored.
    """
    # NOTE(review): fit() takes the <unk> index from self._next_index; that
    # counter is not updated here -- confirm whether it needs to be.
    special = {"<pad>", "<unk>"}
    # The special tokens must be filtered out of the *existing* keys too.
    # ``keys | vocabularies - special`` only strips them from
    # ``vocabularies`` because ``-`` binds tighter than ``|``, which left
    # duplicates in the list and gave ``<pad>`` a non-zero index.
    existing = [token for token in self.code_vocab if token not in special]
    known = special.union(existing)
    added = sorted(set(vocabularies) - known, key=str)
    ordered = ["<pad>", *existing, *added, "<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}
142+
125143
def process(
126144
self, value: Tuple[Optional[List], List]
127145
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:

0 commit comments

Comments
 (0)