Skip to content

Commit 03dadc3

Browse files
committed
Put all the VOCAB_PREFIX_SIZE fiddling in the same place
1 parent 30d0d22 commit 03dadc3

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

stanza/models/depparse/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def pack(x):
283283
else:
284284
loss = 0
285285
preds.append(F.log_softmax(unlabeled_scores, 2).detach().cpu().numpy())
286-
preds.append(deprel_scores.max(3)[1].detach().cpu().numpy())
286+
deprels = deprel_scores.max(3)[1].detach().cpu().numpy() + VOCAB_PREFIX_SIZE
287+
preds.append(deprels)
287288

288289
return loss, preds

stanza/models/depparse/trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
1919
from stanza.models.common.chuliu_edmonds import chuliu_edmonds_one_root
2020
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
21-
from stanza.models.common.vocab import VOCAB_PREFIX_SIZE
2221
from stanza.models.depparse.model import Parser
2322
from stanza.models.pos.vocab import MultiVocab
2423

@@ -150,7 +149,6 @@ def predict(self, batch, unsort=True):
150149
self.model.eval()
151150
batch_size = word.size(0)
152151
_, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
153-
preds[1] += VOCAB_PREFIX_SIZE
154152
head_seqs = [chuliu_edmonds_one_root(adj[:l, :l])[1:] for adj, l in zip(preds[0], sentlens)] # remove attachment for the root
155153
deprel_seqs = [self.vocab['deprel'].unmap([preds[1][i][j+1][h] for j, h in enumerate(hs)]) for i, hs in enumerate(head_seqs)]
156154

0 commit comments

Comments
 (0)