
Commit 359a2e5

small debugging patches to support empty node prediction
1 parent 009c31c

5 files changed: +10 -28 lines


stanza/models/common/doc.py

Lines changed: 3 additions & 3 deletions
@@ -747,10 +747,10 @@ def all_words(self):
         words = self._words
         empty_words = self._empty_words
 
-        all = sorted(words + empty_words, key=lambda x:(x.id,)
-                     if isinstance(x.id, int) else x.id)
+        all_words = sorted(words + empty_words,
+                           key=lambda x:(x.id,) if isinstance(x.id, int) else x.id)
 
-        return all
+        return all_words
 
     @property
     def ents(self):
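
Besides being clearer, the rename stops shadowing Python's builtin all(). The sort key is what lets regular words (integer id) and empty nodes (tuple id, e.g. (8, 1) for CoNLL-U node 8.1) interleave correctly. A minimal sketch of that key, with bare ids standing in for Word objects:

regular_ids = [1, 2, 8, 9]      # ordinary words carry int ids
empty_ids = [(8, 1), (8, 2)]    # empty nodes 8.1 and 8.2 carry tuple ids

# Wrapping ints as 1-tuples makes the two kinds comparable:
# (8,) < (8, 1) < (8, 2) < (9,) under tuple ordering.
merged = sorted(regular_ids + empty_ids,
                key=lambda x: (x,) if isinstance(x, int) else x)
print(merged)  # [1, 2, 8, (8, 1), (8, 2), 9]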

stanza/models/coref/dataset.py

Lines changed: 0 additions & 14 deletions
@@ -38,9 +38,6 @@ def __init__(self, path, config, tokenizer):
             word2subword = []
             subwords = []
             word_id = []
-            nonblank_subwords = []   # a list of subwords, skipping _
-            previous_was_blank = []  # was the word before _?
-            was_blank = False        # a flag to set if we saw "_"
             for i, word in enumerate(doc["cased_words"]):
                 tokenized = self.tokenizer.tokenize(word)
                 if len(tokenized) == 0:
@@ -53,17 +50,6 @@ def __init__(self, path, config, tokenizer):
                 word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))
                 subwords.extend(tokenized_word)
                 word_id.extend([i] * len(tokenized_word))
-                if word == "_":
-                    was_blank = True
-                else:
-                    nonblank_subwords.extend(tokenized_word)
-                    previous_was_blank.extend(
-                        [True if was_blank else False] + [False] * (len(tokenized_word) - 1)
-                    )
-                    was_blank = False
-
-            doc["nonblank_subwords"] = nonblank_subwords
-            doc["blank_prefix"] = previous_was_blank
 
             doc["word2subword"] = word2subword
             doc["subwords"] = subwords

stanza/models/coref/model.py

Lines changed: 3 additions & 6 deletions
@@ -512,13 +512,11 @@ def train(self, log=False):
                 else:
                     s_loss = torch.zeros_like(c_loss)
 
-                del res
-
                 (c_loss + s_loss + z_loss).backward()
 
                 running_c_loss += c_loss.item()
                 running_s_loss += s_loss.item()
-                if z_loss:
+                if res.zero_scores.size(0) != 0:
                     running_z_loss += z_loss.item()
 
                 # log every 100 docs
@@ -527,12 +525,11 @@ def train(self, log=False):
                         'train_c_loss': c_loss.item(),
                         'train_s_loss': s_loss.item(),
                     }
-                    if z_loss:
+                    if res.zero_scores.size(0) != 0:
                         logged['train_z_loss'] = z_loss.item()
                     wandb.log(logged)
 
-
-                del c_loss, s_loss, z_loss
+                del c_loss, s_loss, z_loss, res
 
                 for optim in self.optimizers.values():
                     optim.step()
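
The switch from `if z_loss:` to `if res.zero_scores.size(0) != 0:` sidesteps tensor truthiness: a scalar loss tensor that is exactly 0.0 is falsy, so a genuinely zero loss would silently skip the accounting. It also explains moving `del res` to the end, since the new condition reads res.zero_scores. A small sketch of the pitfall, assuming z_loss is a scalar tensor and zero_scores is empty when a batch has no zero-anaphora candidates:

import torch

z_loss = torch.tensor(0.0)
print(bool(z_loss))    # False: a 0.0 scalar tensor is falsy, so
                       # `if z_loss:` drops a legitimately zero loss

zero_scores = torch.empty(0)     # no zero-anaphora candidates this batch
print(zero_scores.size(0) != 0)  # False: the new check asks whether any
                                 # candidates exist, not what the loss is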

stanza/pipeline/coref_processor.py

Lines changed: 1 addition & 2 deletions
@@ -174,7 +174,7 @@ def process(self, document):
 
             # if we ended up with no best span, then our "representative text"
             # is just underscore
-            if best_span:
+            if best_span is not None:
                 representative = mentions[best_span]
                 representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)
             else:
@@ -205,7 +205,6 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
             cluster_word_ids.extend(cluster)
 
         # Find indices where zero_scores > 0
-        print(zero_scores)
         zero_indices = (zero_scores > 0.0).nonzero()
 
         # this dict maps (cluster_id, word_id) to (cluster_id, start, end)
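
The dropped print(zero_scores) was leftover debug output. The other change matters exactly when the best span is at index 0, a perfectly valid mention index:

best_span = 0                # a valid mention index
if best_span:                # falsy: index 0 is wrongly treated as "no span"
    print("truthiness check fires")
if best_span is not None:    # correct: only None means "no best span"
    print("identity check fires")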

stanza/utils/datasets/coref/convert_udcoref.py

Lines changed: 3 additions & 3 deletions
@@ -129,7 +129,7 @@ def process_documents(docs, augment=False):
                 if sentence_text[span[1]] == "_" and span[1] == span[2]:
                     is_zero.append([span[0], True])
                     zero = True
-                    # oo! thaht's a zero coref, we should merge it forwards
+                    # oo! that's a zero coref, we should merge it forwards
                     # i.e. we pick the next word as the head!
                     span = [span[0], span[1]+1, span[2]+1]
                     # crap! there's two zeros right next to each other
@@ -163,7 +163,7 @@ def process_documents(docs, augment=False):
                 # words from 0, so we have to subtract 1 from the stanza heads
                 #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                 # treat the head of the phrase as the first word that has a head outside the phrase
-                if parsed_sentence.all_words[candidate_head].head and (
+                if (parsed_sentence.all_words[candidate_head].head is not None) and (
                     parsed_sentence.all_words[candidate_head].head - 1 < span[1] or
                     parsed_sentence.all_words[candidate_head].head - 1 > span[2]
                 ):
@@ -205,7 +205,7 @@ def process_documents(docs, augment=False):
             [(old_to_new[start], old_to_new[end - 1] + 1) for start, end in cluster]
             for cluster in span_clusters
         ]
-    except:
+    except (KeyError, TypeError) as _:  # two errors, either end-1 = -1, or start/end is None
         warnings.warn("Somehow, we are still coreffering to a zero. This is likely due to multiple zeros on top of each other. We are giving up.")
         continue
     word_clusters = [
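
Two of these changes harden truthiness and error handling. In CoNLL-U, head == 0 marks the root, so `word.head and (...)` silently treated root words as head-less, while `head is not None` only excludes a genuinely missing parse; and narrowing the bare `except:` keeps unrelated errors (including KeyboardInterrupt) from being swallowed. A sketch of the head check under those CoNLL-U conventions:

for head in (None, 0, 5):      # no parse / root word / ordinary head
    old = bool(head)           # old check: root (head == 0) is falsy
    new = head is not None     # new check: only a missing head fails
    print(head, old, new)
# None False False
# 0 False True   <- root words are no longer skipped
# 5 True True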
