Skip to content

Commit 69ed9db

Browse files
committed
reweight evaluations for coref
1 parent c4d5351 commit 69ed9db

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

stanza/models/coref/model.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,12 @@ def evaluate(self,
144144
s_correct = 0
145145
s_total = 0
146146

147-
z_correct = 0
147+
z_correct = 0 # z_tp
148148
z_total = 0
149149

150+
z_fp = 0
151+
z_fn = 0
152+
150153

151154
with conll.open_(self.config, self.epochs_trained, data_split) \
152155
as (gold_f, pred_f):
@@ -159,11 +162,16 @@ def evaluate(self,
159162
continue
160163

161164
res = self.run(doc, run_zeros=True)
162-
z_acc = ((res.zeros_scores[res.zeros_y != -100].reshape(-1) > 0.5) ==
163-
(res.zeros_y[res.zeros_y != -100].reshape(-1) == 1))
165+
166+
z_preds = (res.zeros_scores[res.zeros_y != -100].reshape(-1) > 0.5)
167+
z_targets = (res.zeros_y[res.zeros_y != -100].reshape(-1) == 1)
168+
169+
z_acc = (z_preds == z_targets)
164170
z_total += z_acc.size(-1)
165171
z_correct += z_acc.sum().item()
166-
172+
173+
z_fp += (z_preds & (~z_targets)).sum().item()
174+
z_fn += ((~z_preds) & z_targets).sum().item()
167175

168176
if (res.coref_y.argmax(dim=1) == 1).all():
169177
logger.warning(f"EVAL: skipping document with no corefs...")
@@ -205,12 +213,15 @@ def evaluate(self,
205213
f" p: {s_lea[1]:.5f},"
206214
f" r: {s_lea[2]:<.5f}"
207215
f" | Z: "
208-
f" acc: {(z_correct / z_total):<.5f}"
216+
f" acc: {(z_correct / z_total):<.5f},"
217+
f" p: {(z_correct / (z_correct + z_fp)):<.5f},"
218+
f" r: {(z_correct / (z_correct + z_fn)):<.5f}"
209219
)
210220
logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}")
211221

212222
return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc,
213-
w_checker.bakeoff, s_checker.bakeoff, z_correct / z_total)
223+
w_checker.bakeoff, s_checker.bakeoff, z_correct / z_total, (z_correct / (z_correct + z_fp)),
224+
(z_correct / (z_correct + z_fn)))
214225

215226
def load_weights(self,
216227
path: Optional[str] = None,
@@ -545,7 +556,18 @@ def train(self, log=False):
545556
if (res.zeros_y == 1).any():
546557
zeros_preds = res.zeros_scores[res.zeros_y != -100].reshape(-1)
547558
labels = res.zeros_y[res.zeros_y != -100].reshape(-1)
548-
zeros_loss = F.binary_cross_entropy(zeros_preds, labels.float())
559+
# reweight such that the zeros and nonzeros count for equal weighting
560+
# that is, artificially balance the "number of samples" by weighting between
561+
# them equally
562+
563+
weight_each_zero = 0.5/labels.sum()
564+
weight_each_nonzero = 0.5/(labels.size(-1) - labels.sum())
565+
566+
weights = torch.empty_like(labels).float()
567+
weights[labels.bool()] = weight_each_zero
568+
weights[~labels.bool()] = weight_each_nonzero
569+
570+
zeros_loss = F.binary_cross_entropy(zeros_preds, labels.float(), weight=weights)
549571
else:
550572
zeros_loss = 0.0 # don't apply loss if there's nothing to learn
551573

@@ -595,7 +617,9 @@ def train(self, log=False):
595617
if log:
596618
wandb.log({'dev_score': scores[1]})
597619
wandb.log({'dev_bakeoff': scores[-2]})
598-
wandb.log({'dev_zeros_acc': scores[-1]})
620+
wandb.log({'dev_zeros_acc': scores[-3],
621+
'dev_zeros_p': scores[-2],
622+
'dev_zeros_r': scores[-1]})
599623

600624
if best_f1 is None or scores[1] > best_f1:
601625

stanza/pipeline/coref_processor.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,7 @@ def process(self, document):
9292
sent_ids.append(sent_idx)
9393
word_pos.append(word_idx)
9494

95-
coref_input = {
96-
"document_id": "wb_doc_1",
97-
"cased_words": cased_words,
98-
"sent_id": sent_ids
99-
}
100-
results = self._model.infer(coref_input)
95+
results = self._model.infer(cased_words, sent_ids)
10196
clusters = []
10297
for span_cluster in results.span_clusters:
10398
if len(span_cluster) == 0:

0 commit comments

Comments
 (0)