|
1 | | -from typing import Any, Dict, List |
| 1 | +from typing import Any, Dict, List, Iterable |
2 | 2 |
|
3 | 3 | import torch |
4 | 4 |
|
5 | 5 | from . import register_processor |
6 | | -from .base_processor import FeatureProcessor |
| 6 | +from .base_processor import FeatureProcessor, VocabMixin |
7 | 7 |
|
8 | 8 |
|
9 | 9 | @register_processor("deep_nested_sequence") |
10 | | -class DeepNestedSequenceProcessor(FeatureProcessor): |
| 10 | +class DeepNestedSequenceProcessor(FeatureProcessor, VocabMixin): |
11 | 11 | """ |
12 | 12 | Feature processor for deeply nested categorical sequences with vocabulary. |
13 | 13 |
|
@@ -51,7 +51,7 @@ def __init__(self): |
51 | 51 | self._max_middle_len = 1 # Maximum length of middle sequences (e.g. visits) |
52 | 52 | self._max_inner_len = 1 # Maximum length of inner sequences (e.g. codes per visit) |
53 | 53 |
|
54 | | - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: |
| 54 | + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: |
55 | 55 | """Build vocabulary and determine maximum sequence lengths. |
56 | 56 |
|
57 | 57 | Args: |
@@ -86,6 +86,27 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None: |
86 | 86 | self._max_middle_len = max(1, max_middle_len) |
87 | 87 | self._max_inner_len = max(1, max_inner_len) |
88 | 88 |
|
def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the processor's vocabulary.

    The special tokens ``<pad>`` (index 0) and ``<unk>`` (index -1) are
    always preserved. Remaining tokens are re-indexed from 1 upward in
    their original insertion order — iterating the existing dict instead
    of a ``set`` keeps the token→index mapping deterministic across runs
    (string-set iteration order varies under hash randomization).

    Args:
        vocabularies: Tokens to drop from ``self.code_vocab``.
    """
    kept = [
        tok
        for tok in self.code_vocab
        if tok not in vocabularies and tok not in ("<pad>", "<unk>")
    ]
    self.code_vocab = {"<pad>": 0, "<unk>": -1}
    for idx, tok in enumerate(kept, start=1):
        self.code_vocab[tok] = idx
| 95 | + |
def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the processor's vocabulary.

    The special tokens ``<pad>`` (index 0) and ``<unk>`` (index -1) are
    always preserved and are excluded from re-indexing even if the caller
    lists them in *vocabularies* (previously they could be re-assigned a
    non-special index). Retained tokens are re-indexed from 1 upward in
    their original insertion order for deterministic mappings across runs.

    Args:
        vocabularies: Tokens to keep in ``self.code_vocab``; everything
            else (except the special tokens) is dropped.
    """
    kept = [
        tok
        for tok in self.code_vocab
        if tok in vocabularies and tok not in ("<pad>", "<unk>")
    ]
    self.code_vocab = {"<pad>": 0, "<unk>": -1}
    for idx, tok in enumerate(kept, start=1):
        self.code_vocab[tok] = idx
| 102 | + |
def add(self, vocabularies: set[str]):
    """Add the given tokens to the processor's vocabulary.

    Fixes a precedence bug: the previous expression
    ``set(self.code_vocab.keys()) | vocabularies - {"<pad>", "<unk>"}``
    binds ``-`` tighter than ``|``, so the special tokens coming from the
    existing vocabulary were never excluded and could be re-assigned
    non-special indices. Existing tokens keep their relative order; new
    tokens are appended in sorted order so the resulting mapping is
    deterministic across runs.

    Args:
        vocabularies: Tokens to merge into ``self.code_vocab``; the
            special tokens ``<pad>``/``<unk>`` are ignored if present.
    """
    existing = [tok for tok in self.code_vocab if tok not in ("<pad>", "<unk>")]
    new_tokens = sorted(
        tok
        for tok in vocabularies
        if tok not in self.code_vocab and tok not in ("<pad>", "<unk>")
    )
    self.code_vocab = {"<pad>": 0, "<unk>": -1}
    for idx, tok in enumerate(existing + new_tokens, start=1):
        self.code_vocab[tok] = idx
| 109 | + |
89 | 110 | def process(self, value: List[List[List[Any]]]) -> torch.Tensor: |
90 | 111 | """Process deep nested sequence into padded 3D tensor. |
91 | 112 |
|
@@ -209,7 +230,7 @@ def __init__(self, forward_fill: bool = True): |
209 | 230 | self._max_inner_len = 1 # Maximum length of inner sequences (values per visit) |
210 | 231 | self.forward_fill = forward_fill |
211 | 232 |
|
212 | | - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: |
| 233 | + def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None: |
213 | 234 | """Determine maximum sequence lengths. |
214 | 235 |
|
215 | 236 | Args: |
|
0 commit comments