9
9
from stanza .pipeline ._constants import *
10
10
from stanza .pipeline .processor import UDProcessor , register_processor
11
11
12
+ import torch
13
+
12
14
def extract_text (document , sent_id , start_word , end_word ):
13
15
sentence = document .sentences [sent_id ]
14
16
tokens = []
@@ -128,6 +130,11 @@ def process(self, document):
128
130
best_span = None
129
131
max_propn = 0
130
132
for span_idx , span in enumerate (span_cluster ):
133
+ word_idx = results .word_clusters [cluster_idx ][span_idx ]
134
+ is_zero = zero_nodes_created .get ((cluster_idx , word_idx ))
135
+ if is_zero :
136
+ continue
137
+
131
138
sent_id = sent_ids [span [0 ]]
132
139
sentence = sentences [sent_id ]
133
140
start_word = word_pos [span [0 ]]
@@ -145,21 +152,33 @@ def process(self, document):
145
152
max_propn = num_propn
146
153
147
154
mentions = []
148
- for span in span_cluster :
149
- sent_id = sent_ids [span [0 ]]
150
- start_word = word_pos [span [0 ]]
151
- end_word = word_pos [span [1 ]- 1 ] + 1
152
- mentions .append (CorefMention (sent_id , start_word , end_word ))
153
-
154
- # Add zero node mentions to this cluster if any exist
155
- for zero_cluster_idx , zero_sent_id , zero_word_decimal_id in zero_nodes_created :
156
- if zero_cluster_idx == cluster_idx :
157
- # Zero node is a single "word" mention at the decimal position
158
- import math
159
- end_word = math .floor (zero_word_decimal_id ) + 1
160
- mentions .append (CorefMention (zero_sent_id , zero_word_decimal_id , end_word ))
161
- representative = mentions [best_span ]
162
- representative_text = extract_text (document , representative .sentence , representative .start_word , representative .end_word )
155
+ for span_idx , span in enumerate (span_cluster ):
156
+ word_idx = results .word_clusters [cluster_idx ][span_idx ]
157
+ is_zero = zero_nodes_created .get ((cluster_idx , word_idx ))
158
+ if is_zero :
159
+ (sent_id , zero_word_id ) = is_zero
160
+ # if the word id is a tuple, it will be attached
161
+ # to the zero
162
+ mentions .append (
163
+ CorefMention (
164
+ sent_id ,
165
+ zero_word_id ,
166
+ zero_word_id
167
+ )
168
+ )
169
+ else :
170
+ sent_id = sent_ids [span [0 ]]
171
+ start_word = word_pos [span [0 ]]
172
+ end_word = word_pos [span [1 ]- 1 ] + 1
173
+ mentions .append (CorefMention (sent_id , start_word , end_word ))
174
+
175
+ # if we ended up with no best span, then our "representative text"
176
+ # is just underscore
177
+ if best_span :
178
+ representative = mentions [best_span ]
179
+ representative_text = extract_text (document , representative .sentence , representative .start_word , representative .end_word )
180
+ else :
181
+ representative_text = "_"
163
182
164
183
chain = CorefChain (len (clusters ), mentions , representative_text , best_span )
165
184
clusters .append (chain )
@@ -173,15 +192,26 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
173
192
return
174
193
175
194
zero_scores = results .zero_scores .squeeze (- 1 ) if results .zero_scores .dim () > 1 else results .zero_scores
195
+ is_zero = []
176
196
177
197
# Flatten word_clusters to get the word indices that correspond to zero_scores
178
198
cluster_word_ids = []
179
- for cluster in results .word_clusters :
199
+ cluster_mapping = {}
200
+ counter = 0
201
+ for indx , cluster in enumerate (results .word_clusters ):
202
+ for _ in range (len (cluster )):
203
+ cluster_mapping [counter ] = indx
204
+ counter += 1
180
205
cluster_word_ids .extend (cluster )
181
206
182
207
# Find indices where zero_scores > 0
183
- zero_indices = (zero_scores > 0 ).nonzero (as_tuple = True )[0 ]
184
-
208
+ print (zero_scores )
209
+ zero_indices = (zero_scores > 0.0 ).nonzero ()
210
+
211
+ # this dict maps (cluster_id, word_id) to (cluster_id, start, end)
212
+ # which overrides span_clusters
213
+ zero_to_coref = {}
214
+
185
215
for zero_idx in zero_indices :
186
216
zero_idx = zero_idx .item ()
187
217
if zero_idx >= len (cluster_word_ids ):
@@ -193,17 +223,21 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
193
223
194
224
# Create zero node - attach BEFORE the current word
195
225
# This means the zero node comes after word_id-1 but before word_id
196
- if word_id > 0 :
197
- zero_word_id = (word_id , 1 ) # attach after word_id-1, before word_id
198
- zero_word = Word (document .sentences [sent_id ], {
199
- "text" : "_" ,
200
- "lemma" : "_" ,
201
- "id" : zero_word_id
202
- })
203
- document .sentences [sent_id ]._empty_words .append (zero_word )
204
-
205
- # Track this zero node for adding to coreference clusters
206
- cluster_idx , _ = cluster_mapping [zero_idx ]
207
- zero_nodes_created .append ((cluster_idx , sent_id , word_id + 0.1 ))
208
-
209
- return zero_nodes_created
226
+ zero_word_id = (
227
+ word_id ,
228
+ len (document .sentences [sent_id ]._empty_words )+ 1
229
+ ) # attach after word_id-1, before word_id
230
+ zero_word = Word (document .sentences [sent_id ], {
231
+ "text" : "_" ,
232
+ "lemma" : "_" ,
233
+ "id" : zero_word_id
234
+ })
235
+ document .sentences [sent_id ]._empty_words .append (zero_word )
236
+
237
+ # Track this zero node for adding to coreference clusters
238
+ cluster_idx = cluster_mapping [zero_idx ]
239
+ zero_to_coref [(cluster_idx , word_idx )] = (
240
+ sent_id , zero_word_id
241
+ )
242
+
243
+ return zero_to_coref
0 commit comments