Skip to content

Commit 284e9b4

Browse files
committed
Temporarily patch an issue with the dependency parser so that it doesn't produce <PAD> as a relation
1 parent bfbd9a1 commit 284e9b4

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

stanza/models/depparse/trainer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
1919
from stanza.models.common.chuliu_edmonds import chuliu_edmonds_one_root
2020
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
21+
from stanza.models.common.vocab import VOCAB_PREFIX_SIZE
2122
from stanza.models.depparse.model import Parser
2223
from stanza.models.pos.vocab import MultiVocab
2324

@@ -73,6 +74,8 @@ def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,
7374
self.model = self.model.to(device)
7475
self.__init_optim()
7576

77+
self.fallback = self.vocab['deprel'].unit2id('dep') if 'dep' in self.vocab['deprel'] else None
78+
7679
if ignore_model_config:
7780
self.args = orig_args
7881

@@ -147,6 +150,9 @@ def predict(self, batch, unsort=True):
147150
self.model.eval()
148151
batch_size = word.size(0)
149152
_, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
153+
# TODO: would be cleaner for the model to not have the capability to produce predictions < VOCAB_PREFIX_SIZE
154+
if self.fallback is not None:
155+
preds[1][preds[1] < VOCAB_PREFIX_SIZE] = self.fallback
150156
head_seqs = [chuliu_edmonds_one_root(adj[:l, :l])[1:] for adj, l in zip(preds[0], sentlens)] # remove attachment for the root
151157
deprel_seqs = [self.vocab['deprel'].unmap([preds[1][i][j+1][h] for j, h in enumerate(hs)]) for i, hs in enumerate(head_seqs)]
152158

0 commit comments

Comments
 (0)