
Commit 7daf8e1

Merge branch 'master' into feature/ivyleavedtoadflax/codecov
2 parents: 959458a + ca12985

File tree: 3 files changed, +223 -14 lines changed


deep_reference_parser/prodigy/prodigy_to_tsv.py

Lines changed: 169 additions & 11 deletions
@@ -8,17 +8,20 @@

 import csv
 import re
+import sys
+from functools import reduce

 import numpy as np
 import plac
-
-from wasabi import Printer
+from wasabi import Printer, table

 from ..io import read_jsonl
 from ..logger import logger

 msg = Printer()

+ROWS_TO_PRINT=15
+

 class TokenLabelPairs:
     """
@@ -183,8 +186,96 @@ def yield_token_label_pair(self, doc, lists=False):
             token_counter += 1


+def get_document_hashes(dataset):
+    """Get the hashes for every doc in a dataset and return as set
+    """
+    return set([doc["_input_hash"] for doc in dataset])
+
+
+def check_all_equal(lst):
+    """Check that all items in a list are equal and return True or False
+    """
+    return not lst or lst.count(lst[0]) == len(lst)
+
+
+def hash_matches(doc, hash):
+    """Check whether the hash of the passed doc matches the passed hash
+    """
+    return doc["_input_hash"] == hash
+
+
+def get_doc_by_hash(dataset, hash):
+    """Return a doc from a dataset where hash matches doc["_input_hash"]
+    Assumes there will only be one match!
+    """
+    return [doc for doc in dataset if doc["_input_hash"] == hash][0]
+
+
+def get_tokens(doc):
+    return [token["text"] for token in doc["tokens"]]
+
+
+def check_inputs(annotated_data):
+    """Checks whether two prodigy datasets contain the same docs (evaluated by
+    doc["_input_hash"] and whether those docs contain the same tokens. This is
+    essential to ensure that two independently labelled datasets are compatible.
+    If they are not, an error is raised with an informative errors message.
+
+    Args:
+        annotated_data (list): List of datasets in prodigy format that have
+            been labelled with token level spans. Hence len(tokens)==len(spans).
+    """
+
+    doc_hashes = list(map(get_document_hashes, annotated_data))
+
+    # Check whether there are the same docs between datasets, and if
+    # not return information on which ones are missing.
+
+    if not check_all_equal(doc_hashes):
+        msg.fail("Some documents missing from one of the input datasets")
+
+        for i in range(len(doc_hashes)):
+            for j in range(i + 1, len(doc_hashes)):
+                diff = set(doc_hashes[i]) ^ set(doc_hashes[j])
+
+                if diff:
+                    msg.fail(
+                        f"Docs {diff} unequal between dataset {i} and {j}", exits=1
+                    )
+
+    # Check that the tokens between the splitting and parsing docs match
+
+    for hash in doc_hashes[0]:
+
+        hash_matches = list(map(lambda x: get_doc_by_hash(x, hash), annotated_data))
+        tokens = list(map(get_tokens, hash_matches))
+
+        if not check_all_equal(tokens):
+            msg.fail(f"Token mismatch for document {hash}", exits=1)
+
+    return True
+
+
+def sort_docs_list(lst):
+    """Sort a list of prodigy docs by input hash
+    """
+    return sorted(lst, key=lambda k: k["_input_hash"])
+
+
+def combine_token_label_pairs(pairs):
+    """Combines a list of [(token, label), (token, label)] to give
+    (token,label,label).
+    """
+    return pairs[0][0:] + tuple(pair[1] for pair in pairs[1:])
+
+
 @plac.annotations(
-    input_file=("Path to jsonl file containing prodigy docs.", "positional", None, str),
+    input_files=(
+        "Comma separated list of paths to jsonl files containing prodigy docs.",
+        "positional",
+        None,
+        str,
+    ),
     output_file=("Path to output tsv file.", "positional", None, str),
     respect_lines=(
         "Respect line endings? Or parse entire document in a single string?",
@@ -201,32 +292,99 @@ def yield_token_label_pair(self, doc, lists=False):
     line_limit=("Number of characters to include on a line", "option", "l", int),
 )
 def prodigy_to_tsv(
-    input_file, output_file, respect_lines, respect_docs, line_limit=250
+    input_files, output_file, respect_lines, respect_docs, line_limit=250
 ):
     """
     Convert token annotated jsonl to token annotated tsv ready for use in the
-    Rodrigues model.
+    deep_reference_parser model.
+
+    Will combine annotations from two jsonl files containing the same docs and
+    the same tokens by comparing the "_input_hash" and token texts. If they are
+    compatible, the output file will contain both labels ready for use in a
+    multi-task model, for example:
+
+    token        label label
+    ------------ ----- -----
+    References   o     o
+                 o     o
+    1            o     o
+    .            o     o
+                 o     o
+    WHO          title b-r
+    treatment    title i-r
+    guidelines   title i-r
+    for          title i-r
+    drug         title i-r
+    -            title i-r
+    resistant    title i-r
+    tuberculosis title i-r
+    ,            title i-r
+    2016         title i-r
+
+    Multiple files must be passed as a comma separated list e.g.
+
+    python -m deep_reference_parser.prodigy prodigy_to_tsv file1.jsonl,file2.jsonl out.tsv
+
     """

+    input_files = input_files.split(",")
+
+    msg.info(f"Loading annotations from {len(input_files)} datasets")
     msg.info(f"Respect line endings: {respect_lines}")
     msg.info(f"Respect doc endings: {respect_docs}")
     msg.info(f"Line limit: {line_limit}")

-    annotated_data = read_jsonl(input_file)
+    # Read the input_files. Note the use of map here, because we don't know
+    # how many sets of annotations area being passed in the list. It could be 2
+    # but in future it may be more.
+
+    annotated_data = list(map(read_jsonl, input_files))
+
+    # Check that the tokens match between sets of annotations. If not raise
+    # errors and stop.
+
+    check_inputs(annotated_data)

-    logger.info("Loaded %s prodigy docs", len(annotated_data))
+    # Sort the docs so that they are in the same order before converting to
+    # token label pairs.
+
+    annotated_data = list(map(sort_docs_list, annotated_data))

     tlp = TokenLabelPairs(
         respect_doc_endings=respect_docs,
         respect_line_endings=respect_lines,
         line_limit=line_limit,
     )
-    token_label_pairs = list(tlp.run(annotated_data))
+
+    pairs_list = list(map(tlp.run, annotated_data))
+
+    # NOTE: Use of reduce to handle pairs_list of unknown length
+
+    if len(pairs_list) > 1:
+        merged_pairs = (
+            combine_token_label_pairs(pairs) for pairs in reduce(zip, pairs_list)
+        )
+        example_pairs = [
+            combine_token_label_pairs(pairs)
+            for i, pairs in enumerate(reduce(zip, pairs_list))
+            if i < ROWS_TO_PRINT
+        ]
+    else:
+        merged_pairs = pairs_list[0]
+        example_pairs = merged_pairs[0:ROWS_TO_PRINT]

     with open(output_file, "w") as fb:
         writer = csv.writer(fb, delimiter="\t")
         # Write DOCSTART and a blank line
-        writer.writerows([("DOCSTART", None), (None, None)])
-        writer.writerows(token_label_pairs)
+        # writer.writerows([("DOCSTART", None), (None, None)])
+        writer.writerows(merged_pairs)
+
+    # Print out the first ten rows as a sense check
+
+    msg.divider("Example output")
+    header = ["token"] + ["label"] * len(annotated_data)
+    aligns = ["r"] + ["l"] * len(annotated_data)
+    formatted = table(example_pairs, header=header, divider=True, aligns=aligns)
+    print(formatted)

-    logger.info("Wrote %s token/label pairs to %s", len(token_label_pairs), output_file)
+    msg.good(f"Wrote token/label pairs to {output_file}")
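
To see what the new merge step in prodigy_to_tsv is doing, here is a minimal, self-contained sketch of how reduce(zip, pairs_list) and combine_token_label_pairs turn two independently labelled copies of the same tokens into (token, label, label) rows. The token and label values below are invented for illustration and are not taken from the commit or its test fixtures.

    from functools import reduce

    def combine_token_label_pairs(pairs):
        # Keep the (token, label) tuple from the first dataset, then append the
        # label from each remaining dataset.
        return pairs[0][0:] + tuple(pair[1] for pair in pairs[1:])

    # Hypothetical TokenLabelPairs.run() output for two annotations of one doc:
    # one labelled with reference types, one with reference spans.
    type_pairs = [("WHO", "title"), ("treatment", "title"), ("guidelines", "title")]
    span_pairs = [("WHO", "b-r"), ("treatment", "i-r"), ("guidelines", "i-r")]

    pairs_list = [type_pairs, span_pairs]

    # With two datasets, reduce(zip, pairs_list) is just zip(type_pairs, span_pairs),
    # yielding (("WHO", "title"), ("WHO", "b-r")) and so on.
    merged = [combine_token_label_pairs(pairs) for pairs in reduce(zip, pairs_list)]
    print(merged)
    # [('WHO', 'title', 'b-r'), ('treatment', 'title', 'i-r'), ('guidelines', 'title', 'i-r')]

Each merged row becomes one tab-separated line in the output file, matching the example table in the new docstring.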

tests/prodigy/test_prodigy_entrypoints.py

Lines changed: 9 additions & 0 deletions
@@ -36,6 +36,15 @@ def test_prodigy_to_tsv(tmpdir):
     )


+def test_prodigy_to_tsv_multiple_inputs(tmpdir):
+    prodigy_to_tsv(
+        TEST_TOKEN_LABELLED + "," + TEST_TOKEN_LABELLED,
+        os.path.join(tmpdir, "tokens.tsv"),
+        respect_lines=False,
+        respect_docs=True,
+    )
+
+
 def test_reach_to_prodigy(tmpdir):
     reach_to_prodigy(TEST_REACH, os.path.join(tmpdir, "prodigy.jsonl"))

tests/prodigy/test_prodigy_to_tsv.py

Lines changed: 45 additions & 3 deletions
@@ -7,7 +7,11 @@
 import pytest

 from deep_reference_parser.io import read_jsonl
-from deep_reference_parser.prodigy.prodigy_to_tsv import TokenLabelPairs, prodigy_to_tsv
+from deep_reference_parser.prodigy.prodigy_to_tsv import (
+    TokenLabelPairs,
+    prodigy_to_tsv,
+    check_inputs,
+)

 from .common import TEST_SPANS, TEST_TOKENS

@@ -738,6 +742,44 @@ def test_reference_spans_real_example(doc):
     tlp = TokenLabelPairs(respect_line_endings=False)
     actual = tlp.run([doc])

-    import pprint
-
     assert actual == expected
+
+
+def test_check_input_exist_on_doc_mismatch():
+
+    dataset_a = [{"_input_hash": "a1"}, {"_input_hash": "a2"}]
+    dataset_b = [{"_input_hash": "b1"}, {"_input_hash": "b2"}]
+
+    with pytest.raises(SystemExit):
+        check_inputs([dataset_a, dataset_b])
+
+
+def test_check_input_exist_on_tokens_mismatch():
+
+    dataset_a = [
+        {"_input_hash": "a", "tokens": [{"text": "a"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
+    ]
+
+    dataset_b = [
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
+    ]
+
+    with pytest.raises(SystemExit):
+        check_inputs([dataset_a, dataset_b])
+
+
+def test_check_input():
+
+    dataset_a = [
+        {"_input_hash": "a", "tokens": [{"text": "a"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
+    ]
+
+    dataset_b = [
+        {"_input_hash": "a", "tokens": [{"text": "a"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
+    ]
+
+    assert check_inputs([dataset_a, dataset_b])
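
The two mismatch tests above expect SystemExit because check_inputs reports problems through wasabi's msg.fail(..., exits=1), which ends the process via sys.exit. As a rough standalone illustration of the hash comparison that triggers this, re-implemented here for clarity (not imported from the package) and using the data from the first test:

    dataset_a = [{"_input_hash": "a1"}, {"_input_hash": "a2"}]
    dataset_b = [{"_input_hash": "b1"}, {"_input_hash": "b2"}]

    def get_document_hashes(dataset):
        # Set of "_input_hash" values, one per doc in the dataset
        return set(doc["_input_hash"] for doc in dataset)

    def check_all_equal(lst):
        return not lst or lst.count(lst[0]) == len(lst)

    doc_hashes = [get_document_hashes(d) for d in (dataset_a, dataset_b)]
    print(check_all_equal(doc_hashes))    # False: the two datasets share no documents
    print(doc_hashes[0] ^ doc_hashes[1])  # {'a1', 'a2', 'b1', 'b2'}, reported before exiting

When the hashes and tokens do line up, as in test_check_input, the function falls through to return True and the conversion proceeds.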
