Skip to content

Commit e1dd0b1

Browse files
committed
[wip] coref processing
1 parent 69ed9db commit e1dd0b1

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

stanza/models/coref/model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,23 +379,36 @@ def infer(self, raw_words, sent_ids) -> CorefResult:
379379
cased_words_new = []
380380
sent_id_new = []
381381

382+
# word to id map
383+
# because inserted new nodes bump ids forward
384+
coref_id_to_real_id_map = {}
385+
382386
for indx,(i,j) in enumerate(zip(raw_words,
383387
((zeros_preds[word_start] > 0.5)
384388
.squeeze(-1)
385389
.tolist()))):
386390
if j:
391+
coref_id_to_real_id_map[len(cased_words_new)] = (indx-1) + 0.5
392+
coref_id_to_real_id_map[len(cased_words_new)+1] = indx
387393
cased_words_new.extend(["_", i])
388394
sent_id_new.extend([sent_ids[indx]]*2)
389395
else:
396+
coref_id_to_real_id_map[len(cased_words_new)] = indx
390397
cased_words_new.append(i)
391398
sent_id_new.append(sent_ids[indx])
392399

393-
return self.run(self.build_doc({
400+
results = self.run(self.build_doc({
394401
"document_id": "wb_doc_1",
395402
"cased_words": cased_words_new,
396403
"sent_id": sent_id_new
397404
}))
398405

406+
return {
407+
"result": results,
408+
"id_mapping": coref_id_to_real_id_map
409+
}
410+
411+
399412

400413
def run(self, # pylint: disable=too-many-locals
401414
doc: Doc,

stanza/pipeline/coref_processor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def process(self, document):
9393
word_pos.append(word_idx)
9494

9595
results = self._model.infer(cased_words, sent_ids)
96+
id_mapping = results["id_mapping"]
97+
results = results["result"]
98+
9699
clusters = []
97100
for span_cluster in results.span_clusters:
98101
if len(span_cluster) == 0:
@@ -137,12 +140,13 @@ def process(self, document):
137140
sent_id = sent_ids[span[0]]
138141
start_word = word_pos[span[0]]
139142
end_word = word_pos[span[1]-1] + 1
140-
mentions.append(CorefMention(sent_id, start_word, end_word))
143+
mentions.append(CorefMention(sent_id, id_mapping[start_word], id_mapping[end_word]))
141144
representative = mentions[best_span]
142145
representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)
143146

144147
chain = CorefChain(len(clusters), mentions, representative_text, best_span)
145148
clusters.append(chain)
146149

150+
breakpoint()
147151
document.coref = clusters
148152
return document

0 commit comments

Comments
 (0)