Commit 5639048

Merge branch 'corefud_v1.3' of github.com:stanfordnlp/stanza into corefud_v1.3

2 parents: 5943ae4 + 1d0863d

File changed: stanza/utils/datasets/coref/convert_udcoref.py (69 additions, 5 deletions)
@@ -10,6 +10,7 @@
 
 from stanza.utils.conll import CoNLL
 
+import warnings
 from random import Random
 
 import argparse
@@ -22,6 +23,7 @@
 UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1
 
 def process_documents(docs, augment=False):
+    # docs: a list of (doc, doc_id, lang) triples, one per document
     processed_section = []
 
     for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)):
@@ -67,8 +69,10 @@ def process_documents(docs, augment=False):
         span_clusters = defaultdict(list)
         word_clusters = defaultdict(list)
         head2span = []
+        is_zero = []
         word_total = 0
         SPANS = re.compile(r"(\(\w+|[%\w]+\))")
+        do_ctn = False  # set to True if we break out of the sentence loop below
         for parsed_sentence in doc.sentences:
             # spans regex
             # parse the misc column, leaving only "Entity" entries
@@ -114,8 +118,29 @@ def process_documents(docs, augment=False):
                     coref_spans.append([int(k), i[0], i[1]])
             sentence_upos = [x.upos for x in parsed_sentence.all_words]
             sentence_heads = [x.head - 1 if x.head and x.head > 0 else None for x in parsed_sentence.all_words]
+            sentence_text = [x.text for x in parsed_sentence.all_words]
+
+            # if "_" in sentence_text and sentence_text.index("_") in [j for i in coref_spans for j in i]:
+            #     import ipdb
+            #     ipdb.set_trace()
 
             for span in coref_spans:
+                zero = False
+                if sentence_text[span[1]] == "_" and span[1] == span[2]:
+                    is_zero.append([span[0], True])
+                    zero = True
+                    # that's a zero coref, so we merge it forwards,
+                    # i.e. we pick the next word as the head
+                    span = [span[0], span[1]+1, span[2]+1]
+                    # if there are two zeros right next to each other,
+                    # we are confused, so we give up on this document
+                    if len(sentence_text) > span[1] and sentence_text[span[1]] == "_":
+                        warnings.warn("Found two zeros next to each other in sequence; we are confused and therefore giving up.")
+                        do_ctn = True
+                        break
+                else:
+                    is_zero.append([span[0], False])
+
                 # input is expected to be start word, end word + 1
                 # counting from 0
                 # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
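The new block above implements a "merge forwards" rule for zero mentions: a mention whose span is a single "_" token is shifted one position to the right, so the next real word serves as its head. Below is a minimal standalone sketch of that rule; the function name and the toy sentence are illustrative, not part of the converter:

import warnings

def merge_zero_forward(sentence_text, span):
    # span is [cluster_id, start, end] with inclusive word indices
    if sentence_text[span[1]] == "_" and span[1] == span[2]:
        # shift the single-token zero span onto the following word
        span = [span[0], span[1] + 1, span[2] + 1]
        # adjacent zeros: the converter gives up on the whole document here
        if len(sentence_text) > span[1] and sentence_text[span[1]] == "_":
            warnings.warn("two adjacent zeros; giving up")
            return None
    return span

words = ["_", "corre", "rápido"]             # dropped pronoun as a zero token
print(merge_zero_forward(words, [3, 0, 0]))  # -> [3, 1, 1]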
@@ -124,10 +149,13 @@ def process_documents(docs, augment=False):
                 # if it's a zero coref (i.e. coref, but the head is None), we call
                 # the beginning of the span (i.e. the zero itself) the head
 
-                try:
-                    candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)
-                except RecursionError:
+                if zero:
                     candidate_head = span[1]
+                else:
+                    try:
+                        candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)
+                    except RecursionError:
+                        candidate_head = span[1]
 
                 if candidate_head is None:
                     for candidate_head in range(span[1], span[2] + 1):
@@ -149,10 +177,45 @@ def process_documents(docs, augment=False):
                 span_clusters[span[0]].append((span_start, span_end))
                 word_clusters[span[0]].append(candidate_head)
                 head2span.append((candidate_head, span_start, span_end))
+            if do_ctn:
+                break
             word_total += len(parsed_sentence.all_words)
+        if do_ctn:
+            continue
         span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
         word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
         head2span = sorted(head2span)
+        is_zero = [i for _, i in sorted(is_zero)]
+
+        # remove zero tokens "_" from cased_words and adjust indices accordingly
+        zero_positions = [i for i, w in enumerate(cased_words) if w == "_"]
+        if zero_positions:
+            old_to_new = {}
+            new_idx = 0
+            for old_idx, w in enumerate(cased_words):
+                if w != "_":
+                    old_to_new[old_idx] = new_idx
+                    new_idx += 1
+            cased_words = [w for w in cased_words if w != "_"]
+            sent_id = [sent_id[i] for i in sorted(old_to_new.keys())]
+            deprel = [deprel[i] for i in sorted(old_to_new.keys())]
+            heads = [heads[i] for i in sorted(old_to_new.keys())]
+            try:
+                span_clusters = [
+                    [(old_to_new[start], old_to_new[end - 1] + 1) for start, end in cluster]
+                    for cluster in span_clusters
+                ]
+            except KeyError:
+                warnings.warn("Somehow, we are still coreferring to a zero. This is likely due to multiple zeros on top of each other. We are giving up.")
+                continue
+            word_clusters = [
+                [old_to_new[h] for h in cluster]
+                for cluster in word_clusters
+            ]
+            head2span = [
+                (old_to_new[h], old_to_new[s], old_to_new[e - 1] + 1)
+                for h, s, e in head2span
+            ]
 
         processed = {
             "document_id": doc_id,
@@ -165,7 +228,8 @@ def process_documents(docs, augment=False):
             "span_clusters": span_clusters,
             "word_clusters": word_clusters,
             "head2span": head2span,
-            "lang": lang
+            "lang": lang,
+            "is_zero": is_zero
         }
         processed_section.append(processed)
     return processed_section
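For orientation, one entry of processed_section now looks roughly like the record below. All values are invented; the "cased_words" key is an assumption inferred from the remapping code above, while the other keys are visible in this diff:

example_record = {                                  # invented toy values
    "document_id": "toy-doc-1",
    "cased_words": ["Juan", "llegó", "tarde"],      # assumed key (see remapping above)
    "span_clusters": [[(0, 1), (2, 3)]],            # half-open (start, end) pairs
    "word_clusters": [[0, 2]],                      # one head word index per mention
    "head2span": [(0, 0, 1), (2, 2, 3)],            # (head, start, end)
    "lang": "es",
    "is_zero": [False, False],                      # new: per-mention flags, sorted by cluster id
}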
@@ -183,6 +247,7 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_
         lang = load.split("/")[-1].split("_")[0]
         print("Ingesting %s from %s of lang %s" % (section, load, lang))
         docs = CoNLL.conll2multi_docs(load, ignore_gapping=False)
+        # sections = docs[:10]
         print(" Ingested %d documents" % len(docs))
         if split_test and section == 'train':
             test_section = []
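A hedged usage sketch of this ingestion path: CoNLL.conll2multi_docs and process_documents are the functions shown in this diff, but the file path, the document ids, and the (doc, doc_id, lang) pairing are assumptions based on the loop header of process_documents:

from stanza.utils.conll import CoNLL
from stanza.utils.datasets.coref.convert_udcoref import process_documents

load = "data/es_ancora-corefud-train.conllu"  # hypothetical input path
lang = load.split("/")[-1].split("_")[0]      # -> "es", as computed above
docs = CoNLL.conll2multi_docs(load, ignore_gapping=False)

# process_documents iterates over (doc, doc_id, lang) triples
triples = [(doc, "%s-doc-%d" % (lang, i), lang) for i, doc in enumerate(docs)]
processed = process_documents(triples)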
@@ -303,4 +368,3 @@ def main():
 
 if __name__ == '__main__':
     main()
-
