Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions stanza/models/common/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,18 +483,22 @@ def coref(self, chains):

def _attach_coref_mentions(self, chains):
for sentence in self.sentences:
for word in sentence.words:
for word in sentence.all_words:
word.coref_chains = []

for chain in chains:
for mention_idx, mention in enumerate(chain.mentions):
sentence = self.sentences[mention.sentence]
for word_idx in range(mention.start_word, mention.end_word):
is_start = word_idx == mention.start_word
is_end = word_idx == mention.end_word - 1
is_representative = mention_idx == chain.representative_index
attachment = CorefAttachment(chain, is_start, is_end, is_representative)
sentence.words[word_idx].coref_chains.append(attachment)
if isinstance(mention.start_word, tuple):
attachment = CorefAttachment(chain, True, True, False)
sentence._empty_words[mention.start_word[1]-1].coref_chains.append(attachment)
else:
for word_idx in range(mention.start_word, mention.end_word):
is_start = word_idx == mention.start_word
is_end = word_idx == mention.end_word - 1
is_representative = mention_idx == chain.representative_index
attachment = CorefAttachment(chain, is_start, is_end, is_representative)
sentence.words[word_idx].coref_chains.append(attachment)

def reindex_sentences(self, start_index):
for sent_id, sentence in zip(range(start_index, start_index + len(self.sentences)), self.sentences):
Expand Down Expand Up @@ -737,6 +741,17 @@ def empty_words(self, value):
""" Set the list of words for this sentence. """
self._empty_words = value

@property
def all_words(self):
""" Access the list of words + empty words for this sentence. """
words = self._words
empty_words = self._empty_words

all_words = sorted(words + empty_words,
key=lambda x:(x.id,) if isinstance(x.id, int) else x.id)

return all_words

@property
def ents(self):
""" Access the list of entities in this sentence. """
Expand Down
2 changes: 2 additions & 0 deletions stanza/models/coref/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ class CorefResult:
rough_scores: torch.Tensor = None # [n_words, n_words]
span_scores: torch.Tensor = None # [n_heads, n_words, 2]
span_y: Tuple[torch.Tensor, torch.Tensor] = None # [n_heads] x2

zero_scores: torch.Tensor = None
2 changes: 2 additions & 0 deletions stanza/models/coref/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ def __init__(self, path, config, tokenizer):
word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))
subwords.extend(tokenized_word)
word_id.extend([i] * len(tokenized_word))

doc["word2subword"] = word2subword
doc["subwords"] = subwords
doc["word_id"] = word_id

self.__out.append(doc)
logger.info("Loaded %d docs from %s.", len(data_f), path)

Expand Down
71 changes: 58 additions & 13 deletions stanza/models/coref/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from stanza.models.coref.rough_scorer import RoughScorer
from stanza.models.coref.span_predictor import SpanPredictor
from stanza.models.coref.utils import GraphNode
from stanza.models.coref.utils import sigmoid_focal_loss
from stanza.models.coref.word_encoder import WordEncoder
from stanza.models.coref.dataset import CorefDataset
from stanza.models.coref.tokenizer_customization import *
Expand All @@ -41,6 +42,8 @@
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper

import torch.nn as nn

logger = logging.getLogger('stanza')

class CorefModel: # pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -140,6 +143,8 @@ def evaluate(self,
running_loss = 0.0
s_correct = 0
s_total = 0
z_correct = 0
z_total = 0

with conll.open_(self.config, self.epochs_trained, data_split) \
as (gold_f, pred_f):
Expand All @@ -150,13 +155,21 @@ def evaluate(self,
# want to test evaluation on one language
continue

res = self.run(doc)
res = self.run(doc, True)
# measure zero prediction accuracy
zero_targets = torch.tensor(doc["is_zero"], device=res.zero_scores.device)
zero_preds = (res.zero_scores > 0).view(-1).to(zero_targets.dtype)
z_correct += (zero_preds == zero_targets).sum().item()
z_total += zero_targets.numel()

if (res.coref_y.argmax(dim=1) == 1).all():
logger.warning(f"EVAL: skipping document with no corefs...")
continue

running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item()
if res.word_clusters is None or res.span_clusters is None:
logger.warning(f"EVAL: skipping document with no clusters...")
continue

if res.span_y:
pred_starts = res.span_scores[:, :, 0].argmax(dim=1)
Expand Down Expand Up @@ -191,8 +204,10 @@ def evaluate(self,
f" f1: {s_lea[0]:.5f},"
f" p: {s_lea[1]:.5f},"
f" r: {s_lea[2]:<.5f}"
f" | ZA: {z_correct / z_total:<.5f}"
)
logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}")
logger.info(f"Zero prediction accuracy: {z_correct / z_total:.5f}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in general, is this always on? i would think there will be datasets that don't have zeros

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in general, reporting this shouldn't hurt, since all we'll have in that case is that all of doc["is_zero"] is False. Hence, that will give us 100% zeros accuracy, and not break any logging. Do you think we should handle those cases differently? The tricky part is that we have currently no way to tell if a dataset has no zeros, or if a batch has no zeros (which is quite likely since zeros are relatively rare.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in that case it doesn't matter too much, although i would think a higher level part of the routine could also look at the whole dataset and check if it has zeros or not. but not a big deal

Copy link
Member Author

@Jemoka Jemoka Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good; I would err on the side of "no" just because technically having "100% zeros accuracy" is technically correct still + involves less post-processing. Your call though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, no strong opinions except that / z_total is probably not ideal in the case of z_total == 0

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have:

                zero_targets = torch.tensor(doc["is_zero"], device=res.zero_scores.device)
                z_total += zero_targets.numel()

so in this case the only situation in which z_total would be the case where the number of elements in doc["is_zero"] is zero for the entire corpus (i.e., the corpus has no length); this would be a bad state and not usually possible.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, will z_correct include documents correctly predicted to have 0 zeros?


return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff)

Expand Down Expand Up @@ -332,6 +347,7 @@ def load_model(path: str,

def run(self, # pylint: disable=too-many-locals
doc: Doc,
use_gold_spans_for_zeros = False
) -> CorefResult:
"""
This is a massive method, but it made sense to me to not split it into
Expand Down Expand Up @@ -380,16 +396,27 @@ def run(self, # pylint: disable=too-many-locals
res.coref_y = self._get_ground_truth(
cluster_ids, top_indices, (top_rough_scores > float("-inf")),
self.config.clusters_starts_are_singletons,
self.config.singletons)
self.config.singletons
)

res.word_clusters = self._clusterize(doc, res.coref_scores, top_indices,
self.config.singletons)
res.word_clusters = self._clusterize(
doc, res.coref_scores, top_indices,
self.config.singletons
)

res.span_scores, res.span_y = self.sp.get_training_data(doc, words)

if not self.training:
res.span_clusters = self.sp.predict(doc, words, res.word_clusters)

if not self.training and not use_gold_spans_for_zeros:
zero_words = words[[word_id
for cluster in res.word_clusters
for word_id in cluster]]
else:
zero_words = words[[i[0] for i in sorted(doc["head2span"])]]
res.zero_scores = self.zeros_predictor(zero_words)

return res

def save_weights(self, save_path=None, save_optimizers=True):
Expand Down Expand Up @@ -454,6 +481,7 @@ def train(self, log=False):
self.log_norms()
running_c_loss = 0.0
running_s_loss = 0.0
running_z_loss = 0.0
random.shuffle(docs_ids)
pbar = tqdm(docs_ids, unit="docs", ncols=0)
for doc_indx, doc_id in enumerate(pbar):
Expand All @@ -468,6 +496,14 @@ def train(self, log=False):

res = self.run(doc)

if res.zero_scores.size(0) == 0:
z_loss = 0.0 # since there are no corefs
else:
z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1),
(torch.tensor(doc["is_zero"])
.to(res.zero_scores.device).float()),
reduction="mean")

c_loss = self._coref_criterion(res.coref_scores, res.coref_y)

if res.span_y:
Expand All @@ -476,20 +512,24 @@ def train(self, log=False):
else:
s_loss = torch.zeros_like(c_loss)

del res

(c_loss + s_loss).backward()
(c_loss + s_loss + z_loss).backward()

running_c_loss += c_loss.item()
running_s_loss += s_loss.item()
if res.zero_scores.size(0) != 0:
running_z_loss += z_loss.item()

# log every 100 docs
if log and doc_indx % 100 == 0:
wandb.log({'train_c_loss': c_loss.item(),
'train_s_loss': s_loss.item()})
logged = {
'train_c_loss': c_loss.item(),
'train_s_loss': s_loss.item(),
}
if res.zero_scores.size(0) != 0:
logged['train_z_loss'] = z_loss.item()
wandb.log(logged)


del c_loss, s_loss
del c_loss, s_loss, z_loss, res

for optim in self.optimizers.values():
optim.step()
Expand All @@ -501,6 +541,7 @@ def train(self, log=False):
f" {doc['document_id']:26}"
f" c_loss: {running_c_loss / (pbar.n + 1):<.5f}"
f" s_loss: {running_s_loss / (pbar.n + 1):<.5f}"
f" z_loss: {running_z_loss / (pbar.n + 1):<.5f}"
)

self.epochs_trained += 1
Expand Down Expand Up @@ -614,12 +655,17 @@ def _build_model(self, foundation_cache):
self.we = WordEncoder(bert_emb, self.config).to(self.config.device)
self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device)
self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device)
self.zeros_predictor = nn.Sequential(
nn.Linear(bert_emb, bert_emb),
nn.ReLU(),
nn.Linear(bert_emb, 1)
).to(self.config.device)

self.trainable: Dict[str, torch.nn.Module] = {
"bert": self.bert, "we": self.we,
"rough_scorer": self.rough_scorer,
"pw": self.pw, "a_scorer": self.a_scorer,
"sp": self.sp
"sp": self.sp, "zeros_predictor": self.zeros_predictor
}

def _build_optimizers(self):
Expand Down Expand Up @@ -785,4 +831,3 @@ def _set_training(self, value: bool):
self._training = value
for module in self.trainable.values():
module.train(self._training)

55 changes: 55 additions & 0 deletions stanza/models/coref/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Set

import torch
import torch.nn.functional as F

from stanza.models.coref.const import EPSILON

Expand Down Expand Up @@ -33,3 +34,57 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
else:
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
return torch.cat((dummy, tensor), dim=1)

def sigmoid_focal_loss(
inputs: torch.Tensor,
targets: torch.Tensor,
alpha: float = 0.25,
gamma: float = 2,
reduction: str = "none",
) -> torch.Tensor:
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

Args:
inputs (Tensor): A float tensor of arbitrary shape.
The predictions for each example.
targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha (float): Weighting factor in range [0, 1] to balance
positive vs negative examples or -1 for ignore. Default: ``0.25``.
gamma (float): Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples. Default: ``2``.
reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
``'none'``: No reduction will be applied to the output.
``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``.
Returns:
Loss tensor with the reduction option applied.
"""
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

if not (0 <= alpha <= 1) and alpha != -1:
raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")

p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)

if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss

# Check reduction option and return loss accordingly
if reduction == "none":
pass
elif reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
else:
raise ValueError(
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
)
return loss
Loading