@@ -387,7 +387,7 @@ def infer(self, raw_words, sent_ids) -> CorefResult:
         coref_id_to_real_id_map = {}

         for indx, (i, j) in enumerate(zip(raw_words,
-                                          ((zeros_preds[word_start] > 0.5)
+                                          ((zeros_preds[word_start].argmax(dim=-1) > 0)
                                            .squeeze(-1)
                                            .tolist()))):
             if j:
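
This hunk swaps a sigmoid threshold for an argmax over two logits. A minimal sketch of the equivalence, with hypothetical shapes (the real `zeros_preds` comes from the model above):

```python
import torch

# Before: one sigmoid probability per word. After: two raw logits per word.
old_scores = torch.rand(5, 1)    # values in [0, 1]
new_logits = torch.randn(5, 2)   # unnormalized scores for classes {0, 1}

old_decisions = (old_scores > 0.5).squeeze(-1).tolist()
# argmax(dim=-1) returns the winning class index, so "> 0" means class 1 won.
new_decisions = (new_logits.argmax(dim=-1) > 0).tolist()
```
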
@@ -575,8 +575,8 @@ def train(self, log=False):
             c_loss = self._coref_criterion(res.coref_scores, res.coref_y)

             if (res.zeros_y == 1).any():
-                zeros_preds = res.zeros_scores[res.zeros_y != -100].reshape(1, -1)
-                labels = res.zeros_y[res.zeros_y != -100].reshape(1, -1)
+                zeros_preds = res.zeros_scores[res.zeros_y != -100].reshape(-1, 2)
+                labels = res.zeros_y[res.zeros_y != -100].reshape(-1, 2)
                 # reweight such that the zeros and nonzeros count for equal weighting
                 # that is, artificially balance the "number of samples" by weighting between
                 # them equally
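
The comment above describes balancing the zero and nonzero classes. One common way to realize that with two-class logits is inverse-frequency class weights in a cross-entropy loss; this is a minimal sketch of the idea, not necessarily the criterion the repo uses, and the labels here are assumed to be integer class ids:

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 2)                       # two logits per sample
labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])  # heavily imbalanced classes

# Weight each class by the inverse of its frequency so both classes
# contribute equally to the loss regardless of the imbalance.
counts = torch.bincount(labels, minlength=2).float()
weights = counts.sum() / (2 * counts.clamp(min=1))

criterion = nn.CrossEntropyLoss(weight=weights)
loss = criterion(logits, labels)
```
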
@@ -668,8 +668,8 @@ def train(self, log=False):
     def _bertify(self, doc: Doc, return_subwords=False) -> torch.Tensor:
         if return_subwords:
             (nonblank_batches,
-             nonblank_labels) = bert.get_subwords_batches(doc, self.config,
-                                                self.tokenizer, nonblank_only=True)
+             nonblank_labels) = bert.get_subwords_batches(doc, self.config,
+                                                          self.tokenizer, nonblank_only=True)
             all_batches = bert.get_subwords_batches(doc, self.config, self.tokenizer)

             # we index the batches n at a time to prevent oom
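
The "n at a time" comment refers to chunking the subword batches through the encoder to bound peak memory. A minimal sketch of that pattern, where `encode` is a hypothetical stand-in for the BERT forward pass:

```python
import torch

def encode_in_chunks(all_batches: torch.Tensor, encode, n: int = 4) -> torch.Tensor:
    outputs = []
    for start in range(0, all_batches.shape[0], n):
        # run only n batches through the model at once to prevent OOM
        outputs.append(encode(all_batches[start:start + n]))
    return torch.cat(outputs, dim=0)
```
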
@@ -768,8 +768,9 @@ def _build_model(self, foundation_cache):
         self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device)
         self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device)
         self.zeros_predictor = nn.Sequential(
-            nn.Linear(self.bert.config.hidden_size, 1),
-            nn.Sigmoid()
+            nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),
+            nn.ReLU(),
+            nn.Linear(self.bert.config.hidden_size, 2),
         ).to(self.config.device)

         self.trainable: Dict[str, torch.nn.Module] = {
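
This hunk replaces the single-output sigmoid head with a small two-layer MLP emitting two raw logits, which is what makes the argmax decision in `infer` and the `(-1, 2)` reshape in `train` line up. A minimal standalone sketch of old vs. new head, with a hypothetical `hidden_size` in place of `self.bert.config.hidden_size`:

```python
import torch
import torch.nn as nn

hidden_size = 768  # hypothetical; the real value comes from the BERT config

# Old head: one sigmoid probability per token, thresholded at 0.5.
old_head = nn.Sequential(nn.Linear(hidden_size, 1), nn.Sigmoid())

# New head: an MLP emitting two raw logits per token, so the decision
# becomes an argmax and the loss can be a class-weighted cross-entropy.
new_head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, 2),
)

x = torch.randn(3, hidden_size)
probs = old_head(x)                     # shape (3, 1), values in [0, 1]
decisions = new_head(x).argmax(dim=-1)  # shape (3,), class indices {0, 1}
```

Keeping the head's output as raw logits rather than sigmoid probabilities also avoids double-squashing if the training criterion applies its own softmax internally.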