Commit be2f434

Add a version of the HE coref conversion script which mixes in udcoref, which seems to help with the results
1 parent 767d463 commit be2f434

3 files changed: +67, -7 lines

stanza/utils/datasets/coref/convert_hebrew_iahlt.py

Lines changed: 12 additions & 2 deletions
@@ -21,6 +21,7 @@
 52 F1, whereas if we use roberta-xlm, we get 50.
 """
 
+import argparse
 from collections import defaultdict, namedtuple
 import json
 import os
@@ -141,8 +142,15 @@ def write_json_file(output_filename, dataset):
     with open(output_filename, "w", encoding="utf-8") as fout:
         json.dump(dataset, fout, indent=2, ensure_ascii=False)
 
-def main():
+def main(args=None):
     paths = get_default_paths()
+    parser = argparse.ArgumentParser(
+        prog='Convert Hebrew IAHLT data',
+    )
+    parser.add_argument('--output_directory', default=None, type=str, help='Where to output the data (defaults to %s)' % paths['COREF_DATA_DIR'])
+    args = parser.parse_args(args=args)
+    coref_output_path = args.output_directory if args.output_directory else paths['COREF_DATA_DIR']
+    print("Will write IAHLT dataset to %s" % coref_output_path)
 
     coref_input_path = paths["COREF_BASE"]
     hebrew_base_path = os.path.join(coref_input_path, "hebrew", "coref", "train_val_test")
@@ -158,8 +166,10 @@ def main():
         docs = read_doc(tokenizer, input_filename)
         dataset = [process_document(pipe, doc.doc_id, "", doc.sentences, doc.coref_spans, None, lang="he") for doc in tqdm(docs)]
 
-        output_filename = os.path.join(paths["COREF_DATA_DIR"], output_filename)
+        output_filename = os.path.join(coref_output_path, output_filename)
        write_json_file(output_filename, dataset)
 
+    return output_files
+
 if __name__ == '__main__':
     main()
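
With main(args=None) now parsing an explicit argument list and returning the output filenames, the converter can be driven from another script as well as from the command line. A minimal sketch of that programmatic use, assuming the returned list is ordered train, dev, test as the mixed-dataset script below relies on:

import tempfile

from stanza.utils.datasets.coref import convert_hebrew_iahlt

with tempfile.TemporaryDirectory() as scratch_dir:
    # equivalent to running the script with --output_directory <scratch_dir>
    filenames = convert_hebrew_iahlt.main(["--output_directory", scratch_dir])
    print(filenames)  # expected: the three split filenames (train, dev, test)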
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+"""
+Build a dataset mixed with IAHLT Hebrew and UD Coref
+
+We find that the IAHLT dataset by itself, trained using Stanza 1.11
+with xlm-roberta-large and a lora finetuning layer, gets 49.7 F1.
+This is a bit lower than the value the IAHLT group originally had, as
+they reported 52. Interestingly, we find that mixing in the 1.3 UD
+Coref improves results, getting 51.7 under the same parameters
+
+This script runs the IAHLT conversion and the UD Coref conversion,
+then combines the files into one big training file
+"""
+
+import json
+import os
+import shutil
+import tempfile
+
+from stanza.utils.datasets.coref import convert_hebrew_iahlt
+from stanza.utils.datasets.coref import convert_udcoref
+from stanza.utils.default_paths import get_default_paths
+
+def main():
+    paths = get_default_paths()
+    coref_output_path = paths['COREF_DATA_DIR']
+    with tempfile.TemporaryDirectory() as temp_dir_path:
+        hebrew_filenames = convert_hebrew_iahlt.main(["--output_directory", temp_dir_path])
+        udcoref_filenames = convert_udcoref.main(["--project", "gerrom", "--output_directory", temp_dir_path])
+
+        with open(os.path.join(temp_dir_path, hebrew_filenames[0]), encoding="utf-8") as fin:
+            hebrew_train = json.load(fin)
+        udcoref_train_filename = os.path.join(temp_dir_path, udcoref_filenames[0])
+        with open(udcoref_train_filename, encoding="utf-8") as fin:
+            print("Reading extra udcoref json data from %s" % udcoref_train_filename)
+            udcoref_train = json.load(fin)
+        mixed_train = hebrew_train + udcoref_train
+        with open(os.path.join(coref_output_path, "he_mixed.train.json"), "w", encoding="utf-8") as fout:
+            json.dump(mixed_train, fout, indent=2, ensure_ascii=False)
+
+        shutil.copyfile(os.path.join(temp_dir_path, hebrew_filenames[1]),
+                        os.path.join(coref_output_path, "he_mixed.dev.json"))
+        shutil.copyfile(os.path.join(temp_dir_path, hebrew_filenames[2]),
+                        os.path.join(coref_output_path, "he_mixed.test.json"))
+
+if __name__ == '__main__':
+    main()
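
Both converters emit a JSON list of processed documents, so the mix is a plain list concatenation before the combined file is written back out. A quick sanity check of the result might look like this (a sketch, assuming the default COREF_DATA_DIR from get_default_paths):

import json
import os

from stanza.utils.default_paths import get_default_paths

paths = get_default_paths()
mixed_filename = os.path.join(paths['COREF_DATA_DIR'], "he_mixed.train.json")
with open(mixed_filename, encoding="utf-8") as fin:
    mixed_train = json.load(fin)
# he_mixed.train.json holds the IAHLT documents followed by the UD Coref documents
print("%d documents in %s" % (len(mixed_train), mixed_filename))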

stanza/utils/datasets/coref/convert_udcoref.py

Lines changed: 9 additions & 5 deletions
@@ -282,13 +282,16 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_
     sections.append(full_test_section)
 
 
+    output_filenames = []
     for section_data, section_name in zip(sections, section_names):
         converted_section = process_documents(section_data, augment=(section_name=="train"))
 
         os.makedirs(coref_output_path, exist_ok=True)
-        output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section_name))
+        output_filenames.append("%s.%s.json" % (short_name, section_name))
+        output_filename = os.path.join(coref_output_path, output_filenames[-1])
         with open(output_filename, "w", encoding="utf-8") as fout:
             json.dump(converted_section, fout, indent=2)
+    return output_filenames
 
 def get_dataset_by_language(coref_input_path, langs):
     conll_path = os.path.join(coref_input_path, "CorefUD-1.3-public", "data")
@@ -301,21 +304,22 @@ def get_dataset_by_language(coref_input_path, langs):
     dev_filenames = sorted(dev_filenames)
     return train_filenames, dev_filenames
 
-def main():
+def main(args=None):
     paths = get_default_paths()
     parser = argparse.ArgumentParser(
         prog='Convert UDCoref Data',
     )
     parser.add_argument('--split_test', default=None, type=float, help='How much of the data to randomly split from train to make a test set')
+    parser.add_argument('--output_directory', default=None, type=str, help='Where to output the data (defaults to %s)' % paths['COREF_DATA_DIR'])
 
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument('--directory', type=str, help="the name of the subfolder for data conversion")
     group.add_argument('--project', type=str, help="Look for and use a set of datasets for data conversion - Slavic or Hungarian")
     group.add_argument('--languages', type=str, help="Only use these specific languages from the coref directory")
 
-    args = parser.parse_args()
+    args = parser.parse_args(args=args)
     coref_input_path = paths['COREF_BASE']
-    coref_output_path = paths['COREF_DATA_DIR']
+    coref_output_path = args.output_directory if args.output_directory else paths['COREF_DATA_DIR']
 
     if args.languages:
         langs = args.languages.split(",")
@@ -369,7 +373,7 @@ def main():
         conll_path = args.directory
         train_filenames = sorted(glob.glob(os.path.join(conll_path, f"*train.conllu")))
         dev_filenames = sorted(glob.glob(os.path.join(conll_path, f"*dev.conllu")))
-    process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames)
+    return process_dataset(project, coref_output_path, args.split_test, train_filenames, dev_filenames)
 
 if __name__ == '__main__':
     main()
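
The same pattern as the IAHLT converter: process_dataset now collects the per-section filenames and main(args=None) returns them, which is what lets the mixed-dataset script above call this module directly. A short sketch of that call, using the same gerrom project argument:

import tempfile

from stanza.utils.datasets.coref import convert_udcoref

with tempfile.TemporaryDirectory() as scratch_dir:
    # the returned names follow the "%s.%s.json" % (short_name, section_name) pattern
    filenames = convert_udcoref.main(["--project", "gerrom", "--output_directory", scratch_dir])
    print(filenames)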
