
import csv
import re
+ import sys
+ from functools import reduce

import numpy as np
import plac
-
- from wasabi import Printer
+ from wasabi import Printer, table

from ..io import read_jsonl
from ..logger import logger

msg = Printer()

+ ROWS_TO_PRINT = 15
+

class TokenLabelPairs:
    """
@@ -183,8 +186,96 @@ def yield_token_label_pair(self, doc, lists=False):
            token_counter += 1


+ def get_document_hashes(dataset):
+     """Get the hash of every doc in a dataset and return them as a set.
+     """
+     return set([doc["_input_hash"] for doc in dataset])
+
+
+ def check_all_equal(lst):
+     """Check that all items in a list are equal and return True or False.
+     """
+     return not lst or lst.count(lst[0]) == len(lst)
+
+
+ def hash_matches(doc, hash):
+     """Check whether the hash of the passed doc matches the passed hash.
+     """
+     return doc["_input_hash"] == hash
+
+
+ def get_doc_by_hash(dataset, hash):
+     """Return a doc from a dataset where hash matches doc["_input_hash"].
+     Assumes there will only be one match!
+     """
+     return [doc for doc in dataset if doc["_input_hash"] == hash][0]
+
+
+ def get_tokens(doc):
+     return [token["text"] for token in doc["tokens"]]
+
+
+ def check_inputs(annotated_data):
+     """Checks whether two prodigy datasets contain the same docs (evaluated by
+     doc["_input_hash"]) and whether those docs contain the same tokens. This is
+     essential to ensure that two independently labelled datasets are compatible.
+     If they are not, an error is raised with an informative error message.
+
+     Args:
+         annotated_data (list): List of datasets in prodigy format that have
+             been labelled with token level spans. Hence len(tokens) == len(spans).
+     """
+
+     doc_hashes = list(map(get_document_hashes, annotated_data))
+
+     # Check whether the same docs are present in each dataset, and if not,
+     # return information on which ones are missing.
+
+     if not check_all_equal(doc_hashes):
+         msg.fail("Some documents missing from one of the input datasets")
+
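+         # Compare each pair of datasets: the symmetric difference of their
+         # hash sets contains the docs that appear in one but not the other.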
+         for i in range(len(doc_hashes)):
+             for j in range(i + 1, len(doc_hashes)):
+                 diff = set(doc_hashes[i]) ^ set(doc_hashes[j])
+
+                 if diff:
+                     msg.fail(
+                         f"Docs {diff} unequal between dataset {i} and {j}", exits=1
+                     )
+
+     # Check that the tokens between the splitting and parsing docs match
+
+     for hash in doc_hashes[0]:
+
+         hash_matches = list(map(lambda x: get_doc_by_hash(x, hash), annotated_data))
+         tokens = list(map(get_tokens, hash_matches))
+
+         if not check_all_equal(tokens):
+             msg.fail(f"Token mismatch for document {hash}", exits=1)
+
+     return True
+
+
+ def sort_docs_list(lst):
+     """Sort a list of prodigy docs by input hash.
+     """
+     return sorted(lst, key=lambda k: k["_input_hash"])
+
+
+ def combine_token_label_pairs(pairs):
+     """Combines a list of [(token, label), (token, label)] to give
+     (token, label, label).
+     """
+     return pairs[0][0:] + tuple(pair[1] for pair in pairs[1:])
+
+
@plac.annotations(
-     input_file=("Path to jsonl file containing prodigy docs.", "positional", None, str),
+     input_files=(
+         "Comma separated list of paths to jsonl files containing prodigy docs.",
+         "positional",
+         None,
+         str,
+     ),
    output_file=("Path to output tsv file.", "positional", None, str),
    respect_lines=(
        "Respect line endings? Or parse entire document in a single string?",
@@ -201,32 +292,99 @@ def yield_token_label_pair(self, doc, lists=False):
    line_limit=("Number of characters to include on a line", "option", "l", int),
)
def prodigy_to_tsv(
-     input_file, output_file, respect_lines, respect_docs, line_limit=250
+     input_files, output_file, respect_lines, respect_docs, line_limit=250
):
    """
    Convert token annotated jsonl to token annotated tsv ready for use in the
-     Rodrigues model.
+     deep_reference_parser model.
+
+     Will combine annotations from two jsonl files containing the same docs and
+     the same tokens by comparing the "_input_hash" and token texts. If they are
+     compatible, the output file will contain both labels ready for use in a
+     multi-task model, for example:
+
+            token   label   label
+     ------------   -----   -----
+       References   o       o
+                    o       o
+                1   o       o
+                .   o       o
+                    o       o
+              WHO   title   b-r
+        treatment   title   i-r
+       guidelines   title   i-r
+              for   title   i-r
+             drug   title   i-r
+                -   title   i-r
+        resistant   title   i-r
+     tuberculosis   title   i-r
+                ,   title   i-r
+             2016   title   i-r
+
+     Multiple files must be passed as a comma separated list, e.g.:
+
+     python -m deep_reference_parser.prodigy prodigy_to_tsv file1.jsonl,file2.jsonl out.tsv
+
    """

+     input_files = input_files.split(",")
+
+     msg.info(f"Loading annotations from {len(input_files)} datasets")
    msg.info(f"Respect line endings: {respect_lines}")
    msg.info(f"Respect doc endings: {respect_docs}")
    msg.info(f"Line limit: {line_limit}")

-     annotated_data = read_jsonl(input_file)
+     # Read the input_files. Note the use of map here, because we don't know
+     # how many sets of annotations are being passed in the list. It could be 2,
+     # but in future it may be more.
+
+     annotated_data = list(map(read_jsonl, input_files))
+
+     # Check that the tokens match between sets of annotations. If not, raise
+     # an error and stop.
+
+     check_inputs(annotated_data)

-     logger.info("Loaded %s prodigy docs", len(annotated_data))
+     # Sort the docs so that they are in the same order before converting to
+     # token label pairs.
+
+     annotated_data = list(map(sort_docs_list, annotated_data))

    tlp = TokenLabelPairs(
        respect_doc_endings=respect_docs,
        respect_line_endings=respect_lines,
        line_limit=line_limit,
    )
-     token_label_pairs = list(tlp.run(annotated_data))
+
+     pairs_list = list(map(tlp.run, annotated_data))
+
+     # NOTE: Use of reduce to handle pairs_list of unknown length
+
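+     # With two datasets, reduce(zip, pairs_list) is simply zip(pairs_list[0],
+     # pairs_list[1]), which pairs up the (token, label) tuples at each position.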
+     if len(pairs_list) > 1:
+         merged_pairs = (
+             combine_token_label_pairs(pairs) for pairs in reduce(zip, pairs_list)
+         )
+         example_pairs = [
+             combine_token_label_pairs(pairs)
+             for i, pairs in enumerate(reduce(zip, pairs_list))
+             if i < ROWS_TO_PRINT
+         ]
+     else:
+         merged_pairs = pairs_list[0]
+         example_pairs = merged_pairs[0:ROWS_TO_PRINT]

    with open(output_file, "w") as fb:
        writer = csv.writer(fb, delimiter="\t")
        # Write DOCSTART and a blank line
-         writer.writerows([("DOCSTART", None), (None, None)])
-         writer.writerows(token_label_pairs)
+         # writer.writerows([("DOCSTART", None), (None, None)])
+         writer.writerows(merged_pairs)
+
+     # Print out the first ROWS_TO_PRINT rows as a sense check
+
+     msg.divider("Example output")
+     header = ["token"] + ["label"] * len(annotated_data)
+     aligns = ["r"] + ["l"] * len(annotated_data)
+     formatted = table(example_pairs, header=header, divider=True, aligns=aligns)
+     print(formatted)

-     logger.info("Wrote %s token/label pairs to %s", len(token_label_pairs), output_file)
+     msg.good(f"Wrote token/label pairs to {output_file}")