|
9 | 9 | import transformers as tfs |
10 | 10 | from bratlib import data as brat_data |
11 | 11 | from bratlib.data.extensions.instance import ContigEntity |
12 | | -from bratlib.tools.validation import validate_bratfile |
13 | 12 | from spacy.tokens.span import Span |
14 | 13 |
|
15 | 14 |
|
@@ -119,6 +118,10 @@ def _pseudofy_side(self, rel: brat_data.Relation, sentence: Span, k: int, do_lef |
119 | 118 | token_tensor = torch.tensor(indexed_tokens) |
120 | 119 | mask_tensor = torch.tensor([token != '[MASK]' for token in tokenized_sentence], dtype=torch.float) |
121 | 120 |
|
| 121 | + if len(token_tensor) > 512: |
| 122 | +                # Skip sequences longer than 512 tokens — the max input length of standard
| 123 | +                # BERT models (bert-base/bert-large); adjust if a different model is used.
| 124 | + |
122 | 125 | with torch.no_grad(): |
123 | 126 | result = self.bert(token_tensor.unsqueeze(0), mask_tensor.unsqueeze(0), masked_lm_labels=None) |
124 | 127 |
|
@@ -175,17 +178,9 @@ def pseudofy_file(self, ann: brat_data.BratFile) -> SentenceGenerator: |
175 | 178 | sentences = [find_sentence(arg, sentence_ranges) for arg in |
176 | 179 | (rel.arg1.start, rel.arg1.end, rel.arg2.start, rel.arg2.end)] |
177 | 180 |
|
178 | | - first = min(sentences, key=lambda x: x.start_char) |
179 | | - last = max(sentences, key=lambda x: x.end_char) |
180 | | - if first is last: |
181 | | - # The ideal case, both args are in the same sentence |
182 | | - text_span = first |
183 | | - elif first.end_char + 1 == last.start_char: |
184 | | - # The args are in two adjacent sentences |
185 | | - text_span = doc[first.start_char:last.end_char] |
186 | | - else: |
187 | | - # The args are more than two sentences apart; we will ignore these |
188 | | - continue |
| 181 | + first = min(sentences, key=lambda x: x.start_char).start_char |
| 182 | + last = max(sentences, key=lambda x: x.end_char).end_char |
| 183 | + text_span = doc[first:last] |
189 | 184 |
|
190 | 185 | yield from filter(self.filter, self.pseudofy_relation(rel, text_span)) |
191 | 186 |
|
@@ -221,10 +216,6 @@ def _pseudofy_file(self, ann: brat_data.BratFile, output_dir: Path) -> None: |
221 | 216 | with pseudo_ann.open('w+') as f: |
222 | 217 | f.write(str(new_ann)) |
223 | 218 |
|
224 | | - if __debug__: |
225 | | - ann = brat_data.BratFile.from_ann_path(pseudo_ann) |
226 | | - assert all(validate_bratfile(ann)) |
227 | | - |
228 | 219 | def pseudofy_dataset(self, dataset: brat_data.BratDataset, output_dir: Path) -> brat_data.BratDataset: |
229 | 220 | for ann in dataset: |
230 | 221 | self._pseudofy_file(ann, output_dir) |
|
0 commit comments