
Commit 6875446

chg: Extend prodigy_to_tsv to handle multiple labels
1 parent a0869ca commit 6875446

File tree

1 file changed (+145, -11 lines)


deep_reference_parser/prodigy/prodigy_to_tsv.py

Lines changed: 145 additions & 11 deletions
@@ -8,11 +8,12 @@
 
 import csv
 import re
+import sys
+from functools import reduce
 
 import numpy as np
 import plac
-
-from wasabi import Printer
+from wasabi import Printer, table
 
 from ..io import read_jsonl
 from ..logger import logger
@@ -182,9 +183,81 @@ def yield_token_label_pair(self, doc, lists=False):
 
             token_counter += 1
 
+def get_document_hashes(dataset):
+    """Get the hashes for every doc in a dataset and return as a set
+    """
+    return set([doc["_input_hash"] for doc in dataset])
+
+def check_all_equal(lst):
+    """Check that all items in a list are equal and return True or False
+    """
+    return not lst or lst.count(lst[0]) == len(lst)
+
+def hash_matches(doc, hash):
+    """Check whether the hash of the passed doc matches the passed hash
+    """
+    return doc["_input_hash"] == hash
+
+def get_doc_by_hash(dataset, hash):
+    """Return a doc from a dataset where hash matches doc["_input_hash"].
+    Assumes there will only be one match!
+    """
+    return [doc for doc in dataset if doc["_input_hash"] == hash][0]
+
+def get_tokens(doc):
+    return [token["text"] for token in doc["tokens"]]
+
+def check_inputs(annotated_data):
+    """Checks whether two prodigy datasets contain the same docs (evaluated by
+    doc["_input_hash"]) and whether those docs contain the same tokens. This is
+    essential to ensure that two independently labelled datasets are compatible.
+    If they are not, an error is raised with an informative error message.
+
+    Args:
+        annotated_data (list): List of datasets in prodigy format that have
+            been labelled with token level spans. Hence len(tokens)==len(spans).
+    """
+
+    doc_hashes = list(map(get_document_hashes, annotated_data))
+
+    # Check whether the same docs are present in each dataset, and if not
+    # return information on which ones are missing.
+
+    if not check_all_equal(doc_hashes):
+        msg.fail("Some documents missing from one of the input datasets")
+
+        for i in range(len(doc_hashes)):
+            for j in range(i + 1, len(doc_hashes)):
+                diff = set(doc_hashes[i]) ^ set(doc_hashes[j])
+
+                if diff:
+                    msg.fail(f"Docs {diff} unequal between dataset {i} and {j}", exits=1)
+
+    # Check that the tokens between the splitting and parsing docs match
+
+    for hash in doc_hashes[0]:
+
+        matching_docs = list(map(lambda x: get_doc_by_hash(x, hash), annotated_data))
+        tokens = list(map(get_tokens, matching_docs))
+
+        if not check_all_equal(tokens):
+            msg.fail(f"Token mismatch for document {hash}", exits=1)
+
+    return True
+
+def sort_docs_list(lst):
+    """Sort a list of prodigy docs by input hash
+    """
+    return sorted(lst, key=lambda k: k['_input_hash'])
+
+def combine_token_label_pairs(pairs):
+    """Combines a list of [(token, label), (token, label)] to give
+    (token, label, label).
+    """
+    return pairs[0][0:] + tuple(pair[1] for pair in pairs[1:])
 
 
 @plac.annotations(
-    input_file=("Path to jsonl file containing prodigy docs.", "positional", None, str),
+    input_files=("Comma separated list of paths to jsonl files containing prodigy docs.", "positional", None, str),
     output_file=("Path to output tsv file.", "positional", None, str),
     respect_lines=(
        "Respect line endings? Or parse entire document in a single string?",
@@ -201,32 +274,93 @@ def yield_token_label_pair(self, doc, lists=False):
     line_limit=("Number of characters to include on a line", "option", "l", int),
 )
 def prodigy_to_tsv(
-    input_file, output_file, respect_lines, respect_docs, line_limit=250
+    input_files, output_file, respect_lines, respect_docs, line_limit=250
 ):
     """
     Convert token annotated jsonl to token annotated tsv ready for use in the
-    Rodrigues model.
+    deep_reference_parser model.
+
+    Will combine annotations from two jsonl files containing the same docs and
+    the same tokens by comparing the "_input_hash" and token texts. If they are
+    compatible, the output file will contain both labels ready for use in a
+    multi-task model, for example:
+
+        token        label label
+        ------------ ----- -----
+        References   o     o
+                     o     o
+        1            o     o
+        .            o     o
+                     o     o
+        WHO          title b-r
+        treatment    title i-r
+        guidelines   title i-r
+        for          title i-r
+        drug         title i-r
+        -            title i-r
+        resistant    title i-r
+        tuberculosis title i-r
+        ,            title i-r
+        2016         title i-r
+
+    Multiple files must be passed as a comma separated list e.g.
+
+    python -m deep_reference_parser.prodigy prodigy_to_tsv file1.jsonl,file2.jsonl out.tsv
+
     """
 
+    input_files = input_files.split(",")
+
+    msg.info(f"Loading annotations from {len(input_files)} datasets")
     msg.info(f"Respect line endings: {respect_lines}")
     msg.info(f"Respect doc endings: {respect_docs}")
     msg.info(f"Line limit: {line_limit}")
 
-    annotated_data = read_jsonl(input_file)
+    # Read the input_files. Note the use of map here, because we don't know
+    # how many sets of annotations are being passed in the list. It could be 2
+    # but in future it may be more.
 
-    logger.info("Loaded %s prodigy docs", len(annotated_data))
+    annotated_data = list(map(read_jsonl, input_files))
+
+    # Check that the tokens match between sets of annotations. If not, raise
+    # errors and stop.
+
+    check_inputs(annotated_data)
+
+    # Sort the docs so that they are in the same order before converting to
+    # token label pairs.
+
+    annotated_data = list(map(sort_docs_list, annotated_data))
 
     tlp = TokenLabelPairs(
         respect_doc_endings=respect_docs,
         respect_line_endings=respect_lines,
         line_limit=line_limit,
     )
-    token_label_pairs = list(tlp.run(annotated_data))
+
+    pairs_list = list(map(tlp.run, annotated_data))
+
+    # NOTE: Use of reduce to handle pairs_list of unknown length
+
+    if len(pairs_list) > 1:
+        merged_pairs = (combine_token_label_pairs(pairs) for pairs in reduce(zip, pairs_list))
+        example_pairs = [combine_token_label_pairs(pairs) for i, pairs in enumerate(reduce(zip, pairs_list)) if i < 15]
+    else:
+        merged_pairs = pairs_list[0]
+        example_pairs = merged_pairs[0:14]
 
     with open(output_file, "w") as fb:
         writer = csv.writer(fb, delimiter="\t")
         # Write DOCSTART and a blank line
-        writer.writerows([("DOCSTART", None), (None, None)])
-        writer.writerows(token_label_pairs)
+        # writer.writerows([("DOCSTART", None), (None, None)])
+        writer.writerows(merged_pairs)
+
+    # Print out the first few rows as a sense check
+
+    msg.divider("Example output")
+    header = ["token"] + ["label"] * len(annotated_data)
+    aligns = ["r"] + ["l"] * len(annotated_data)
+    formatted = table(example_pairs, header=header, divider=True, aligns=aligns)
+    print(formatted)
 
-    logger.info("Wrote %s token/label pairs to %s", len(token_label_pairs), output_file)
+    msg.good(f"Wrote token/label pairs to {output_file}")

0 commit comments