|
10 | 10 |
|
11 | 11 | from stanza.utils.conll import CoNLL
|
12 | 12 |
|
| 13 | +import warnings |
13 | 14 | from random import Random
|
14 | 15 |
|
15 | 16 | import argparse
|
@@ -71,6 +72,7 @@ def process_documents(docs, augment=False):
|
71 | 72 | is_zero = []
|
72 | 73 | word_total = 0
|
73 | 74 | SPANS = re.compile(r"(\(\w+|[%\w]+\))")
|
| 75 | + do_ctn = False # if we broke in the loop |
74 | 76 | for parsed_sentence in doc.sentences:
|
75 | 77 | # spans regex
|
76 | 78 | # parse the misc column, leaving on "Entity" entries
|
@@ -130,6 +132,12 @@ def process_documents(docs, augment=False):
|
130 | 132 | # oo! that's a zero coref, we should merge it forwards
|
131 | 133 | # i.e. we pick the next word as the head!
|
132 | 134 | span = [span[0], span[1]+1, span[2]+1]
|
| 135 | + # crap! there's two zeros right next to each other |
| 136 | + # we are sad and confused so we give up in this case |
| 137 | + if len(sentence_text) > span[1] and sentence_text[span[1]] == "_": |
| 138 | + warnings.warn("Found two zeros next to each other in sequence; we are confused and therefore giving up.") |
| 139 | + do_ctn = True |
| 140 | + break |
133 | 141 | else:
|
134 | 142 | is_zero.append([span[0], False])
|
135 | 143 |
|
@@ -169,12 +177,46 @@ def process_documents(docs, augment=False):
|
169 | 177 | span_clusters[span[0]].append((span_start, span_end))
|
170 | 178 | word_clusters[span[0]].append(candidate_head)
|
171 | 179 | head2span.append((candidate_head, span_start, span_end))
|
| 180 | + if do_ctn: |
| 181 | + break |
172 | 182 | word_total += len(parsed_sentence.all_words)
|
| 183 | + if do_ctn: |
| 184 | + continue |
173 | 185 | span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
|
174 | 186 | word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
|
175 | 187 | head2span = sorted(head2span)
|
176 | 188 | is_zero = [i for _,i in sorted(is_zero)]
|
177 | 189 |
|
| 190 | + # remove zero tokens "_" from cased_words and adjust indices accordingly |
| 191 | + zero_positions = [i for i, w in enumerate(cased_words) if w == "_"] |
| 192 | + if zero_positions: |
| 193 | + old_to_new = {} |
| 194 | + new_idx = 0 |
| 195 | + for old_idx, w in enumerate(cased_words): |
| 196 | + if w != "_": |
| 197 | + old_to_new[old_idx] = new_idx |
| 198 | + new_idx += 1 |
| 199 | + cased_words = [w for w in cased_words if w != "_"] |
| 200 | + sent_id = [sent_id[i] for i in sorted(old_to_new.keys())] |
| 201 | + deprel = [deprel[i] for i in sorted(old_to_new.keys())] |
| 202 | + heads = [heads[i] for i in sorted(old_to_new.keys())] |
| 203 | + try: |
| 204 | + span_clusters = [ |
| 205 | + [(old_to_new[start], old_to_new[end - 1] + 1) for start, end in cluster] |
| 206 | + for cluster in span_clusters |
| 207 | + ] |
| 208 | + except: |
| 209 | + warnings.warn("Somehow, we are still coreffering to a zero. This is likely due to multiple zeros on top of each other. We are giving up.") |
| 210 | + continue |
| 211 | + word_clusters = [ |
| 212 | + [old_to_new[h] for h in cluster] |
| 213 | + for cluster in word_clusters |
| 214 | + ] |
| 215 | + head2span = [ |
| 216 | + (old_to_new[h], old_to_new[s], old_to_new[e - 1] + 1) |
| 217 | + for h, s, e in head2span |
| 218 | + ] |
| 219 | + |
178 | 220 | processed = {
|
179 | 221 | "document_id": doc_id,
|
180 | 222 | "cased_words": cased_words,
|
@@ -338,4 +380,3 @@ def main():
|
338 | 380 | ["./extern_data/coref/corefud_v1_3/hu_szegedkoref-corefud-dev.conllu"],
|
339 | 381 | ["./extern_data/coref/corefud_v1_3/hu_szegedkoref-corefud-dev.conllu"]
|
340 | 382 | )
|
341 |
| - |
|
0 commit comments