Skip to content

Commit e22685f

Browse files
committed
fixes for zero coref inference
1 parent 06c3da4 commit e22685f

File tree

2 files changed

+77
-39
lines changed

2 files changed

+77
-39
lines changed

stanza/models/common/doc.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -483,18 +483,22 @@ def coref(self, chains):
483483

484484
def _attach_coref_mentions(self, chains):
485485
for sentence in self.sentences:
486-
for word in sentence.words:
486+
for word in sentence.all_words:
487487
word.coref_chains = []
488488

489489
for chain in chains:
490490
for mention_idx, mention in enumerate(chain.mentions):
491491
sentence = self.sentences[mention.sentence]
492-
for word_idx in range(mention.start_word, mention.end_word):
493-
is_start = word_idx == mention.start_word
494-
is_end = word_idx == mention.end_word - 1
495-
is_representative = mention_idx == chain.representative_index
496-
attachment = CorefAttachment(chain, is_start, is_end, is_representative)
497-
sentence.words[word_idx].coref_chains.append(attachment)
492+
if isinstance(mention.start_word, tuple):
493+
attachment = CorefAttachment(chain, True, True, False)
494+
sentence._empty_words[mention.start_word[1]-1].coref_chains.append(attachment)
495+
else:
496+
for word_idx in range(mention.start_word, mention.end_word):
497+
is_start = word_idx == mention.start_word
498+
is_end = word_idx == mention.end_word - 1
499+
is_representative = mention_idx == chain.representative_index
500+
attachment = CorefAttachment(chain, is_start, is_end, is_representative)
501+
sentence.words[word_idx].coref_chains.append(attachment)
498502

499503
def reindex_sentences(self, start_index):
500504
for sent_id, sentence in zip(range(start_index, start_index + len(self.sentences)), self.sentences):

stanza/pipeline/coref_processor.py

Lines changed: 66 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from stanza.pipeline._constants import *
1010
from stanza.pipeline.processor import UDProcessor, register_processor
1111

12+
import torch
13+
1214
def extract_text(document, sent_id, start_word, end_word):
1315
sentence = document.sentences[sent_id]
1416
tokens = []
@@ -128,6 +130,11 @@ def process(self, document):
128130
best_span = None
129131
max_propn = 0
130132
for span_idx, span in enumerate(span_cluster):
133+
word_idx = results.word_clusters[cluster_idx][span_idx]
134+
is_zero = zero_nodes_created.get((cluster_idx, word_idx))
135+
if is_zero:
136+
continue
137+
131138
sent_id = sent_ids[span[0]]
132139
sentence = sentences[sent_id]
133140
start_word = word_pos[span[0]]
@@ -145,21 +152,33 @@ def process(self, document):
145152
max_propn = num_propn
146153

147154
mentions = []
148-
for span in span_cluster:
149-
sent_id = sent_ids[span[0]]
150-
start_word = word_pos[span[0]]
151-
end_word = word_pos[span[1]-1] + 1
152-
mentions.append(CorefMention(sent_id, start_word, end_word))
153-
154-
# Add zero node mentions to this cluster if any exist
155-
for zero_cluster_idx, zero_sent_id, zero_word_decimal_id in zero_nodes_created:
156-
if zero_cluster_idx == cluster_idx:
157-
# Zero node is a single "word" mention at the decimal position
158-
import math
159-
end_word = math.floor(zero_word_decimal_id) + 1
160-
mentions.append(CorefMention(zero_sent_id, zero_word_decimal_id, end_word))
161-
representative = mentions[best_span]
162-
representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)
155+
for span_idx, span in enumerate(span_cluster):
156+
word_idx = results.word_clusters[cluster_idx][span_idx]
157+
is_zero = zero_nodes_created.get((cluster_idx, word_idx))
158+
if is_zero:
159+
(sent_id, zero_word_id) = is_zero
160+
# a tuple word id marks a zero (empty) node; the mention
161+
# is later attached to that empty word rather than a regular word
162+
mentions.append(
163+
CorefMention(
164+
sent_id,
165+
zero_word_id,
166+
zero_word_id
167+
)
168+
)
169+
else:
170+
sent_id = sent_ids[span[0]]
171+
start_word = word_pos[span[0]]
172+
end_word = word_pos[span[1]-1] + 1
173+
mentions.append(CorefMention(sent_id, start_word, end_word))
174+
175+
# if we ended up with no best span, then our "representative text"
176+
# is just underscore
177+
if best_span:
178+
representative = mentions[best_span]
179+
representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)
180+
else:
181+
representative_text = "_"
163182

164183
chain = CorefChain(len(clusters), mentions, representative_text, best_span)
165184
clusters.append(chain)
@@ -173,15 +192,26 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
173192
return
174193

175194
zero_scores = results.zero_scores.squeeze(-1) if results.zero_scores.dim() > 1 else results.zero_scores
195+
is_zero = []
176196

177197
# Flatten word_clusters to get the word indices that correspond to zero_scores
178198
cluster_word_ids = []
179-
for cluster in results.word_clusters:
199+
cluster_mapping = {}
200+
counter = 0
201+
for indx, cluster in enumerate(results.word_clusters):
202+
for _ in range(len(cluster)):
203+
cluster_mapping[counter] = indx
204+
counter += 1
180205
cluster_word_ids.extend(cluster)
181206

182207
# Find indices where zero_scores > 0
183-
zero_indices = (zero_scores > 0).nonzero(as_tuple=True)[0]
184-
208+
print(zero_scores)
209+
zero_indices = (zero_scores > 0.0).nonzero()
210+
211+
# this dict maps (cluster_idx, word_idx) to (sent_id, zero_word_id),
212+
# which overrides span_clusters
213+
zero_to_coref = {}
214+
185215
for zero_idx in zero_indices:
186216
zero_idx = zero_idx.item()
187217
if zero_idx >= len(cluster_word_ids):
@@ -193,17 +223,21 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
193223

194224
# Create zero node - attach BEFORE the current word
195225
# This means the zero node comes after word_id-1 but before word_id
196-
if word_id > 0:
197-
zero_word_id = (word_id, 1) # attach after word_id-1, before word_id
198-
zero_word = Word(document.sentences[sent_id], {
199-
"text": "_",
200-
"lemma": "_",
201-
"id": zero_word_id
202-
})
203-
document.sentences[sent_id]._empty_words.append(zero_word)
204-
205-
# Track this zero node for adding to coreference clusters
206-
cluster_idx, _ = cluster_mapping[zero_idx]
207-
zero_nodes_created.append((cluster_idx, sent_id, word_id + 0.1))
208-
209-
return zero_nodes_created
226+
zero_word_id = (
227+
word_id,
228+
len(document.sentences[sent_id]._empty_words)+1
229+
) # attach after word_id-1, before word_id
230+
zero_word = Word(document.sentences[sent_id], {
231+
"text": "_",
232+
"lemma": "_",
233+
"id": zero_word_id
234+
})
235+
document.sentences[sent_id]._empty_words.append(zero_word)
236+
237+
# Track this zero node for adding to coreference clusters
238+
cluster_idx = cluster_mapping[zero_idx]
239+
zero_to_coref[(cluster_idx, word_idx)] = (
240+
sent_id, zero_word_id
241+
)
242+
243+
return zero_to_coref

0 commit comments

Comments
 (0)