Commit a334cf9

reshape into 2 class preds?
1 parent 61bc248 commit a334cf9

1 file changed: +8 −7 lines changed

stanza/models/coref/model.py

Lines changed: 8 additions & 7 deletions
@@ -387,7 +387,7 @@ def infer(self, raw_words, sent_ids) -> CorefResult:
         coref_id_to_real_id_map = {}

         for indx,(i,j) in enumerate(zip(raw_words,
-                                        ((zeros_preds[word_start] > 0.5)
+                                        ((zeros_preds[word_start].argmax(dim=-1) > 0)
                                          .squeeze(-1)
                                          .tolist()))):
             if j:
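
The hunk above replaces a single sigmoid probability thresholded at 0.5 with an argmax over two class logits. A minimal sketch of the new decision rule, using a hypothetical zeros_preds tensor of shape (n_words, 2); the values and the dict layout are illustrative only:

import torch

# Hypothetical per-word scores: two logits per word instead of one sigmoid probability.
zeros_preds = {0: torch.tensor([[ 1.2, -0.3],
                                [-0.8,  2.1],
                                [ 0.5,  0.4]])}
word_start = 0

# Old rule: sigmoid(logit) > 0.5.  New rule: class index 1 wins the argmax.
flags = (zeros_preds[word_start].argmax(dim=-1) > 0).tolist()
print(flags)  # [False, True, False]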
@@ -575,8 +575,8 @@ def train(self, log=False):
             c_loss = self._coref_criterion(res.coref_scores, res.coref_y)

             if (res.zeros_y == 1).any():
-                zeros_preds = res.zeros_scores[res.zeros_y != -100].reshape(1, -1)
-                labels = res.zeros_y[res.zeros_y != -100].reshape(1, -1)
+                zeros_preds = res.zeros_scores[res.zeros_y != -100].reshape(-1, 2)
+                labels = res.zeros_y[res.zeros_y != -100].reshape(-1, 2)
                 # reweight such that the zeros and nonzeros count for equal weighting
                 # that is, artifically balance the "number of samples" by weighting between
                 # them equally
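
On the training side, the reshape now keeps two logits per word rather than flattening everything into a single row. A hedged sketch of the shape difference with stand-in tensors; the criterion actually applied to zeros_preds and labels is outside this hunk, so the loss call below is illustrative only:

import torch
import torch.nn as nn

scores = torch.randn(6, 2)                             # two logits per candidate word
zeros_y = torch.tensor([0, 1, -100, 1, 0, 1])          # -100 marks ignored positions

old_shape = scores[zeros_y != -100].reshape(1, -1)     # (1, 2 * n_kept): one flat row
new_shape = scores[zeros_y != -100].reshape(-1, 2)     # (n_kept, 2): one row per word

# With integer 0/1 targets, the two-column layout pairs naturally with cross entropy:
targets = zeros_y[zeros_y != -100]                     # shape (n_kept,)
loss = nn.CrossEntropyLoss()(new_shape, targets)

Note that the diff also reshapes labels to (-1, 2), which would instead suit a criterion that takes per-class targets; the cross-entropy call here only illustrates how logits of shape (n_kept, 2) are typically consumed.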
@@ -668,8 +668,8 @@ def train(self, log=False):
     def _bertify(self, doc: Doc, return_subwords=False) -> torch.Tensor:
         if return_subwords:
             (nonblank_batches,
-             nonblank_labels) = bert.get_subwords_batches(doc, self.config,
-                                                         self.tokenizer, nonblank_only=True)
+             nonblank_labels) = bert.get_subwords_batches(doc, self.config,
+                                                          self.tokenizer, nonblank_only=True)
             all_batches = bert.get_subwords_batches(doc, self.config, self.tokenizer)

             # we index the batches n at a time to prevent oom
@@ -768,8 +768,9 @@ def _build_model(self, foundation_cache):
         self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device)
         self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device)
         self.zeros_predictor = nn.Sequential(
-            nn.Linear(self.bert.config.hidden_size, 1),
-            nn.Sigmoid()
+            nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),
+            nn.ReLU(),
+            nn.Linear(self.bert.config.hidden_size, 2),
         ).to(self.config.device)

         self.trainable: Dict[str, torch.nn.Module] = {
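
The head itself goes from one sigmoid-squashed output to a small MLP emitting two raw class logits. A minimal standalone sketch, with hidden_size standing in for self.bert.config.hidden_size:

import torch
import torch.nn as nn

hidden_size = 768                          # stand-in for self.bert.config.hidden_size

old_head = nn.Sequential(                  # previous head: one probability per token
    nn.Linear(hidden_size, 1),
    nn.Sigmoid(),
)

new_head = nn.Sequential(                  # new head: two raw logits per token
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, 2),
)

x = torch.randn(4, hidden_size)            # four token embeddings
print(old_head(x).shape)                   # torch.Size([4, 1]) -> threshold at 0.5
print(new_head(x).shape)                   # torch.Size([4, 2]) -> argmax(dim=-1)

Emitting raw logits means the sigmoid-threshold decoding elsewhere in the model has to move to an argmax, which is exactly what the inference change at line 390 does.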
