 
 import csv
 import re
+import sys
+from functools import reduce
 
 import numpy as np
 import plac
-
-from wasabi import Printer
+from wasabi import Printer, table
 
 from ..io import read_jsonl
 from ..logger import logger
@@ -182,9 +183,81 @@ def yield_token_label_pair(self, doc, lists=False):
 
             token_counter += 1
 
+def get_document_hashes(dataset):
+    """Get the hashes for every doc in a dataset and return them as a set."""
+    return set([doc["_input_hash"] for doc in dataset])
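+# e.g. get_document_hashes([{"_input_hash": 123}, {"_input_hash": 456}]) == {123, 456}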
+
+def check_all_equal(lst):
+    """Check that all items in a list are equal and return True or False."""
+    return not lst or lst.count(lst[0]) == len(lst)
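+# e.g. check_all_equal([{1, 2}, {1, 2}]) is True; check_all_equal([{1}, {1, 2}]) is False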
+
+def hash_matches(doc, hash):
+    """Check whether the hash of the passed doc matches the passed hash."""
+    return doc["_input_hash"] == hash
+
+def get_doc_by_hash(dataset, hash):
+    """Return a doc from a dataset where hash matches doc["_input_hash"].
+    Assumes there will only be one match!
+    """
+    return [doc for doc in dataset if doc["_input_hash"] == hash][0]
+
+def get_tokens(doc):
+    return [token["text"] for token in doc["tokens"]]
+
+def check_inputs(annotated_data):
+    """Checks whether two prodigy datasets contain the same docs (evaluated by
+    doc["_input_hash"]) and whether those docs contain the same tokens. This is
+    essential to ensure that two independently labelled datasets are compatible.
+    If they are not, an error is raised with an informative error message.
+
+    Args:
+        annotated_data (list): List of datasets in prodigy format that have
+            been labelled with token level spans. Hence len(tokens)==len(spans).
+    """
+
+    doc_hashes = list(map(get_document_hashes, annotated_data))
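+    # doc_hashes is a list containing one set of _input_hash values per input dataset.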
+
+    # Check whether the same docs appear in every dataset and, if not, report
+    # which ones are missing.
+
+    if not check_all_equal(doc_hashes):
+        msg.fail("Some documents missing from one of the input datasets")
+
+        for i in range(len(doc_hashes)):
+            for j in range(i + 1, len(doc_hashes)):
+                diff = set(doc_hashes[i]) ^ set(doc_hashes[j])
+
+                if diff:
+                    msg.fail(f"Docs {diff} unequal between dataset {i} and {j}", exits=1)
+
+    # Check that the tokens match between corresponding docs in each dataset.
+
+    for hash in doc_hashes[0]:
+
+        matching_docs = list(map(lambda x: get_doc_by_hash(x, hash), annotated_data))
+        tokens = list(map(get_tokens, matching_docs))
+
+        if not check_all_equal(tokens):
+            msg.fail(f"Token mismatch for document {hash}", exits=1)
+
+    return True
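+# check_inputs exits via msg.fail(..., exits=1) at the first incompatibility it
+# finds, so returning True means every dataset covers the same _input_hashes
+# with identical token sequences.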
+
+def sort_docs_list(lst):
+    """Sort a list of prodigy docs by input hash."""
+    return sorted(lst, key=lambda k: k['_input_hash'])
+
+def combine_token_label_pairs(pairs):
+    """Combines a list of [(token, label), (token, label)] to give
+    (token, label, label).
+    """
+    return pairs[0][0:] + tuple(pair[1] for pair in pairs[1:])
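+# e.g. combine_token_label_pairs([("WHO", "title"), ("WHO", "b-r")]) == ("WHO", "title", "b-r")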
 
 @plac.annotations(
-    input_file=("Path to jsonl file containing prodigy docs.", "positional", None, str),
+    input_files=("Comma-separated list of paths to jsonl files containing prodigy docs.", "positional", None, str),
     output_file=("Path to output tsv file.", "positional", None, str),
     respect_lines=(
         "Respect line endings? Or parse entire document in a single string?",
@@ -201,32 +274,93 @@ def yield_token_label_pair(self, doc, lists=False):
     line_limit=("Number of characters to include on a line", "option", "l", int),
 )
 def prodigy_to_tsv(
-    input_file, output_file, respect_lines, respect_docs, line_limit=250
+    input_files, output_file, respect_lines, respect_docs, line_limit=250
 ):
     """
     Convert token annotated jsonl to token annotated tsv ready for use in the
-    Rodrigues model.
+    deep_reference_parser model.
+
+    Will combine annotations from two jsonl files containing the same docs and
+    the same tokens by comparing the "_input_hash" and token texts. If they are
+    compatible, the output file will contain both labels ready for use in a
+    multi-task model, for example:
+
+        token           label    label
+        ------------    -----    -----
+        References      o        o
+                        o        o
+        1               o        o
+        .               o        o
+                        o        o
+        WHO             title    b-r
+        treatment       title    i-r
+        guidelines      title    i-r
+        for             title    i-r
+        drug            title    i-r
+        -               title    i-r
+        resistant       title    i-r
+        tuberculosis    title    i-r
+        ,               title    i-r
+        2016            title    i-r
+
+    Multiple files must be passed as a comma-separated list, e.g.:
+
+        python -m deep_reference_parser.prodigy prodigy_to_tsv file1.jsonl,file2.jsonl out.tsv
+
     """
 
+    input_files = input_files.split(",")
+
+    msg.info(f"Loading annotations from {len(input_files)} datasets")
     msg.info(f"Respect line endings: {respect_lines}")
     msg.info(f"Respect doc endings: {respect_docs}")
     msg.info(f"Line limit: {line_limit}")
 
-    annotated_data = read_jsonl(input_file)
+    # Read the input_files. Note the use of map here, because we don't know
+    # how many sets of annotations are being passed in the list. It could be 2,
+    # but in future it may be more.
 
-    logger.info("Loaded %s prodigy docs", len(annotated_data))
+    annotated_data = list(map(read_jsonl, input_files))
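+    # annotated_data is a list with one entry per input file, each entry being a
+    # list of prodigy doc dicts.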
+
+    # Check that the tokens match between sets of annotations. If not, raise an
+    # error and stop.
+
+    check_inputs(annotated_data)
+
+    # Sort the docs so that they are in the same order before converting to
+    # token label pairs.
+
+    annotated_data = list(map(sort_docs_list, annotated_data))
 
     tlp = TokenLabelPairs(
         respect_doc_endings=respect_docs,
         respect_line_endings=respect_lines,
         line_limit=line_limit,
     )
-    token_label_pairs = list(tlp.run(annotated_data))
+
+    pairs_list = list(map(tlp.run, annotated_data))
+
+    # NOTE: Use of reduce to handle a pairs_list of unknown length.
+
+    if len(pairs_list) > 1:
+        merged_pairs = (combine_token_label_pairs(pairs) for pairs in reduce(zip, pairs_list))
+        example_pairs = [combine_token_label_pairs(pairs) for i, pairs in enumerate(reduce(zip, pairs_list)) if i < 15]
+    else:
+        merged_pairs = pairs_list[0]
+        example_pairs = merged_pairs[0:15]
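+    # For two input datasets, each element of merged_pairs is one output row of
+    # the form (token, label_a, label_b), e.g. ("WHO", "title", "b-r").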
 
     with open(output_file, "w") as fb:
         writer = csv.writer(fb, delimiter="\t")
         # Write DOCSTART and a blank line
-        writer.writerows([("DOCSTART", None), (None, None)])
-        writer.writerows(token_label_pairs)
+        # writer.writerows([("DOCSTART", None), (None, None)])
+        writer.writerows(merged_pairs)
+
+    # Print the first few rows of output as a sense check.
+
+    msg.divider("Example output")
+    header = ["token"] + ["label"] * len(annotated_data)
+    aligns = ["r"] + ["l"] * len(annotated_data)
+    formatted = table(example_pairs, header=header, divider=True, aligns=aligns)
+    print(formatted)
 
-    logger.info("Wrote %s token/label pairs to %s", len(token_label_pairs), output_file)
+    msg.good(f"Wrote token/label pairs to {output_file}")