10
10
11
11
from stanza .utils .conll import CoNLL
12
12
13
+ import warnings
13
14
from random import Random
14
15
15
16
import argparse
22
23
UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1
23
24
24
25
def process_documents (docs , augment = False ):
26
+ # docs = sections
25
27
processed_section = []
26
28
27
29
for idx , (doc , doc_id , lang ) in enumerate (tqdm (docs )):
@@ -67,8 +69,10 @@ def process_documents(docs, augment=False):
67
69
span_clusters = defaultdict (list )
68
70
word_clusters = defaultdict (list )
69
71
head2span = []
72
+ is_zero = []
70
73
word_total = 0
71
74
SPANS = re .compile (r"(\(\w+|[%\w]+\))" )
75
+ do_ctn = False # if we broke in the loop
72
76
for parsed_sentence in doc .sentences :
73
77
# spans regex
74
78
# parse the misc column, leaving on "Entity" entries
@@ -114,8 +118,29 @@ def process_documents(docs, augment=False):
114
118
coref_spans .append ([int (k ), i [0 ], i [1 ]])
115
119
sentence_upos = [x .upos for x in parsed_sentence .all_words ]
116
120
sentence_heads = [x .head - 1 if x .head and x .head > 0 else None for x in parsed_sentence .all_words ]
121
+ sentence_text = [x .text for x in parsed_sentence .all_words ]
122
+
123
+ # if "_" in sentence_text and sentence_text.index("_") in [j for i in coref_spans for j in i]:
124
+ # import ipdb
125
+ # ipdb.set_trace()
117
126
118
127
for span in coref_spans :
128
+ zero = False
129
+ if sentence_text [span [1 ]] == "_" and span [1 ] == span [2 ]:
130
+ is_zero .append ([span [0 ], True ])
131
+ zero = True
132
+ # oo! thaht's a zero coref, we should merge it forwards
133
+ # i.e. we pick the next word as the head!
134
+ span = [span [0 ], span [1 ]+ 1 , span [2 ]+ 1 ]
135
+ # crap! there's two zeros right next to each other
136
+ # we are sad and confused so we give up in this case
137
+ if len (sentence_text ) > span [1 ] and sentence_text [span [1 ]] == "_" :
138
+ warnings .warn ("Found two zeros next to each other in sequence; we are confused and therefore giving up." )
139
+ do_ctn = True
140
+ break
141
+ else :
142
+ is_zero .append ([span [0 ], False ])
143
+
119
144
# input is expected to be start word, end word + 1
120
145
# counting from 0
121
146
# whereas the OntoNotes coref_span is [start_word, end_word] inclusive
@@ -124,10 +149,13 @@ def process_documents(docs, augment=False):
124
149
# if its a zero coref (i.e. coref, but the head in None), we call
125
150
# the beginning of the span (i.e. the zero itself) the head
126
151
127
- try :
128
- candidate_head = find_cconj_head (sentence_heads , sentence_upos , span [1 ], span [2 ]+ 1 )
129
- except RecursionError :
152
+ if zero :
130
153
candidate_head = span [1 ]
154
+ else :
155
+ try :
156
+ candidate_head = find_cconj_head (sentence_heads , sentence_upos , span [1 ], span [2 ]+ 1 )
157
+ except RecursionError :
158
+ candidate_head = span [1 ]
131
159
132
160
if candidate_head is None :
133
161
for candidate_head in range (span [1 ], span [2 ] + 1 ):
@@ -149,10 +177,45 @@ def process_documents(docs, augment=False):
149
177
span_clusters [span [0 ]].append ((span_start , span_end ))
150
178
word_clusters [span [0 ]].append (candidate_head )
151
179
head2span .append ((candidate_head , span_start , span_end ))
180
+ if do_ctn :
181
+ break
152
182
word_total += len (parsed_sentence .all_words )
183
+ if do_ctn :
184
+ continue
153
185
span_clusters = sorted ([sorted (values ) for _ , values in span_clusters .items ()])
154
186
word_clusters = sorted ([sorted (values ) for _ , values in word_clusters .items ()])
155
187
head2span = sorted (head2span )
188
+ is_zero = [i for _ ,i in sorted (is_zero )]
189
+
190
+ # remove zero tokens "_" from cased_words and adjust indices accordingly
191
+ zero_positions = [i for i , w in enumerate (cased_words ) if w == "_" ]
192
+ if zero_positions :
193
+ old_to_new = {}
194
+ new_idx = 0
195
+ for old_idx , w in enumerate (cased_words ):
196
+ if w != "_" :
197
+ old_to_new [old_idx ] = new_idx
198
+ new_idx += 1
199
+ cased_words = [w for w in cased_words if w != "_" ]
200
+ sent_id = [sent_id [i ] for i in sorted (old_to_new .keys ())]
201
+ deprel = [deprel [i ] for i in sorted (old_to_new .keys ())]
202
+ heads = [heads [i ] for i in sorted (old_to_new .keys ())]
203
+ try :
204
+ span_clusters = [
205
+ [(old_to_new [start ], old_to_new [end - 1 ] + 1 ) for start , end in cluster ]
206
+ for cluster in span_clusters
207
+ ]
208
+ except :
209
+ warnings .warn ("Somehow, we are still coreffering to a zero. This is likely due to multiple zeros on top of each other. We are giving up." )
210
+ continue
211
+ word_clusters = [
212
+ [old_to_new [h ] for h in cluster ]
213
+ for cluster in word_clusters
214
+ ]
215
+ head2span = [
216
+ (old_to_new [h ], old_to_new [s ], old_to_new [e - 1 ] + 1 )
217
+ for h , s , e in head2span
218
+ ]
156
219
157
220
processed = {
158
221
"document_id" : doc_id ,
@@ -165,7 +228,8 @@ def process_documents(docs, augment=False):
165
228
"span_clusters" : span_clusters ,
166
229
"word_clusters" : word_clusters ,
167
230
"head2span" : head2span ,
168
- "lang" : lang
231
+ "lang" : lang ,
232
+ "is_zero" : is_zero
169
233
}
170
234
processed_section .append (processed )
171
235
return processed_section
@@ -183,6 +247,7 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_
183
247
lang = load .split ("/" )[- 1 ].split ("_" )[0 ]
184
248
print ("Ingesting %s from %s of lang %s" % (section , load , lang ))
185
249
docs = CoNLL .conll2multi_docs (load , ignore_gapping = False )
250
+ # sections = docs[:10]
186
251
print (" Ingested %d documents" % len (docs ))
187
252
if split_test and section == 'train' :
188
253
test_section = []
@@ -303,4 +368,3 @@ def main():
303
368
304
369
if __name__ == '__main__' :
305
370
main ()
306
-
0 commit comments