@@ -144,9 +144,12 @@ def evaluate(self,
         s_correct = 0
         s_total = 0
 
-        z_correct = 0
+        z_correct = 0  # z_tp
         z_total = 0
 
+        z_fp = 0
+        z_fn = 0
+
 
         with conll.open_(self.config, self.epochs_trained, data_split) \
                 as (gold_f, pred_f):
@@ -159,11 +162,16 @@ def evaluate(self,
                     continue
 
                 res = self.run(doc, run_zeros=True)
-                z_acc = ((res.zeros_scores[res.zeros_y != -100].reshape(-1) > 0.5) ==
-                         (res.zeros_y[res.zeros_y != -100].reshape(-1) == 1))
+
+                z_preds = (res.zeros_scores[res.zeros_y != -100].reshape(-1) > 0.5)
+                z_targets = (res.zeros_y[res.zeros_y != -100].reshape(-1) == 1)
+
+                z_acc = (z_preds == z_targets)
                 z_total += z_acc.size(-1)
                 z_correct += z_acc.sum().item()
-
+
+                z_fp += (z_preds & (~z_targets)).sum().item()
+                z_fn += ((~z_preds) & z_targets).sum().item()
 
                 if (res.coref_y.argmax(dim=1) == 1).all():
                     logger.warning(f"EVAL: skipping document with no corefs...")
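
For reference, a minimal standalone sketch of how the boolean masks above turn into precision and recall; the helper name zero_prf and the toy call are illustrative and not part of the patch, which only accumulates the raw counts:

import torch

def zero_prf(z_preds: torch.Tensor, z_targets: torch.Tensor):
    # Confusion counts over the boolean zero-anaphora predictions.
    tp = (z_preds & z_targets).sum().item()
    fp = (z_preds & ~z_targets).sum().item()
    fn = (~z_preds & z_targets).sum().item()
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

# e.g. zero_prf(torch.tensor([True, False, True]), torch.tensor([True, True, False]))
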
@@ -205,12 +213,15 @@ def evaluate(self,
                     f" p: {s_lea[1]:.5f},"
                     f" r: {s_lea[2]:<.5f}"
                     f" | Z: "
-                    f" acc: {(z_correct / z_total):<.5f}"
+                    f" acc: {(z_correct / z_total):<.5f},"
+                    f" p: {(z_correct / (z_correct + z_fp)):<.5f},"
+                    f" r: {(z_correct / (z_correct + z_fn)):<.5f}"
         )
         logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}")
 
         return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc,
-                w_checker.bakeoff, s_checker.bakeoff, z_correct / z_total)
+                w_checker.bakeoff, s_checker.bakeoff, z_correct / z_total, (z_correct / (z_correct + z_fp)),
+                (z_correct / (z_correct + z_fn)))
 
     def load_weights(self,
                      path: Optional[str] = None,
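
For orientation, the tuple returned by evaluate() now ends with the three zero-anaphora numbers, which is what the scores[-3], scores[-2], scores[-1] indexing in the training-loop hunk further below relies on. A sketch, borrowing the scores name from that hunk:

# ..., w_checker.bakeoff, s_checker.bakeoff, zeros_acc, zeros_p, zeros_r
zeros_acc, zeros_p, zeros_r = scores[-3:]
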
@@ -545,7 +556,18 @@ def train(self, log=False):
                 if (res.zeros_y == 1).any():
                     zeros_preds = res.zeros_scores[res.zeros_y != -100].reshape(-1)
                     labels = res.zeros_y[res.zeros_y != -100].reshape(-1)
-                    zeros_loss = F.binary_cross_entropy(zeros_preds, labels.float())
+                    # reweight such that the zeros and nonzeros count for equal weighting;
+                    # that is, artificially balance the "number of samples" by weighting
+                    # the two classes equally
+
+                    weight_each_zero = 0.5 / labels.sum()
+                    weight_each_nonzero = 0.5 / (labels.size(-1) - labels.sum())
+
+                    weights = torch.empty_like(labels).float()
+                    weights[labels.bool()] = weight_each_zero
+                    weights[~labels.bool()] = weight_each_nonzero
+
+                    zeros_loss = F.binary_cross_entropy(zeros_preds, labels.float(), weight=weights)
                 else:
                     zeros_loss = 0.0  # don't apply loss if there's nothing to learn
 
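
As a standalone illustration of the balancing trick above: each class ends up with half of the total weight mass no matter how imbalanced the batch is. The toy tensors are hypothetical, and torch.where stands in for the in-place masked assignments:

import torch
import torch.nn.functional as F

probs = torch.tensor([0.9, 0.2, 0.7, 0.1, 0.4])   # stand-in for zeros_preds
labels = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0])  # stand-in for labels.float()

n_pos = labels.sum()            # number of zero-anaphora positives
n_neg = labels.numel() - n_pos  # number of negatives

# Positives share 0.5 of the weight, negatives share the other 0.5.
weights = torch.where(labels.bool(), 0.5 / n_pos, 0.5 / n_neg)

loss = F.binary_cross_entropy(probs, labels, weight=weights)

With the default mean reduction this averages the per-element weighted terms, so the rare class contributes a comparable share of the gradient signal.
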
@@ -595,7 +617,9 @@ def train(self, log=False):
             if log:
                 wandb.log({'dev_score': scores[1]})
                 wandb.log({'dev_bakeoff': scores[-2]})
-                wandb.log({'dev_zeros_acc': scores[-1]})
+                wandb.log({'dev_zeros_acc': scores[-3],
+                           'dev_zeros_p': scores[-2],
+                           'dev_zeros_r': scores[-1]})
 
             if best_f1 is None or scores[1] > best_f1: