@@ -282,13 +282,16 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_
282282 sections .append (full_test_section )
283283
284284
285+ output_filenames = []
285286 for section_data , section_name in zip (sections , section_names ):
286287 converted_section = process_documents (section_data , augment = (section_name == "train" ))
287288
288289 os .makedirs (coref_output_path , exist_ok = True )
289- output_filename = os .path .join (coref_output_path , "%s.%s.json" % (short_name , section_name ))
290+ output_filenames .append ("%s.%s.json" % (short_name , section_name ))
291+ output_filename = os .path .join (coref_output_path , output_filenames [- 1 ])
290292 with open (output_filename , "w" , encoding = "utf-8" ) as fout :
291293 json .dump (converted_section , fout , indent = 2 )
294+ return output_filenames
292295
293296def get_dataset_by_language (coref_input_path , langs ):
294297 conll_path = os .path .join (coref_input_path , "CorefUD-1.3-public" , "data" )
@@ -301,21 +304,22 @@ def get_dataset_by_language(coref_input_path, langs):
301304 dev_filenames = sorted (dev_filenames )
302305 return train_filenames , dev_filenames
303306
304- def main ():
307+ def main (args = None ):
305308 paths = get_default_paths ()
306309 parser = argparse .ArgumentParser (
307310 prog = 'Convert UDCoref Data' ,
308311 )
309312 parser .add_argument ('--split_test' , default = None , type = float , help = 'How much of the data to randomly split from train to make a test set' )
313+ parser .add_argument ('--output_directory' , default = None , type = str , help = 'Where to output the data (defaults to %s)' % paths ['COREF_DATA_DIR' ])
310314
311315 group = parser .add_mutually_exclusive_group (required = True )
312316 group .add_argument ('--directory' , type = str , help = "the name of the subfolder for data conversion" )
313317 group .add_argument ('--project' , type = str , help = "Look for and use a set of datasets for data conversion - Slavic or Hungarian" )
314318 group .add_argument ('--languages' , type = str , help = "Only use these specific languages from the coref directory" )
315319
316- args = parser .parse_args ()
320+ args = parser .parse_args (args = args )
317321 coref_input_path = paths ['COREF_BASE' ]
318- coref_output_path = paths ['COREF_DATA_DIR' ]
322+ coref_output_path = args . output_directory if args . output_directory else paths ['COREF_DATA_DIR' ]
319323
320324 if args .languages :
321325 langs = args .languages .split ("," )
@@ -369,7 +373,7 @@ def main():
369373 conll_path = args .directory
370374 train_filenames = sorted (glob .glob (os .path .join (conll_path , f"*train.conllu" )))
371375 dev_filenames = sorted (glob .glob (os .path .join (conll_path , f"*dev.conllu" )))
372- process_dataset (project , coref_output_path , args .split_test , train_filenames , dev_filenames )
376+ return process_dataset (project , coref_output_path , args .split_test , train_filenames , dev_filenames )
373377
374378if __name__ == '__main__' :
375379 main ()
0 commit comments