
Commit 55bc5e2

chg: linting with 🏴💋
1 parent: 136df23

File tree: 3 files changed (+46, -15 lines)


deep_reference_parser/prodigy/prodigy_to_tsv.py

Lines changed: 28 additions & 6 deletions
@@ -183,30 +183,36 @@ def yield_token_label_pair(self, doc, lists=False):
             token_counter += 1
 
+
 def get_document_hashes(dataset):
     """Get the hashes for every doc in a dataset and return as set
     """
     return set([doc["_input_hash"] for doc in dataset])
 
+
 def check_all_equal(lst):
     """Check that all items in a list are equal and return True or False
     """
     return not lst or lst.count(lst[0]) == len(lst)
 
+
 def hash_matches(doc, hash):
     """Check whether the hash of the passed doc matches the passed hash
     """
     return doc["_input_hash"] == hash
 
+
 def get_doc_by_hash(dataset, hash):
     """Return a doc from a dataset where hash matches doc["_input_hash"]
     Assumes there will only be one match!
     """
     return [doc for doc in dataset if doc["_input_hash"] == hash][0]
 
+
 def get_tokens(doc):
     return [token["text"] for token in doc["tokens"]]
 
+
 def check_inputs(annotated_data):
     """Checks whether two prodigy datasets contain the same docs (evaluated by
     doc["_input_hash"] and whether those docs contain the same tokens. This is
@@ -231,7 +237,9 @@ def check_inputs(annotated_data):
         diff = set(doc_hashes[i]) ^ set(doc_hashes[j])
 
         if diff:
-            msg.fail(f"Docs {diff} unequal between dataset {i} and {j}", exits=1)
+            msg.fail(
+                f"Docs {diff} unequal between dataset {i} and {j}", exits=1
+            )
 
     # Check that the tokens between the splitting and parsing docs match
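
To make the check concrete: the symmetric difference collects hashes that appear in exactly one of the two datasets, and any non-empty result aborts the run. Toy values for illustration:

doc_hashes_i = {"a1", "a2", "a3"}
doc_hashes_j = {"a1", "a2", "a4"}

diff = doc_hashes_i ^ doc_hashes_j  # -> {"a3", "a4"}: docs present in only one dataset
# A non-empty diff triggers msg.fail(..., exits=1), which exits the process,
# hence the pytest.raises(SystemExit) assertions in the tests below.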

@@ -245,19 +253,27 @@ def check_inputs(annotated_data):
 
     return True
 
+
 def sort_docs_list(lst):
     """Sort a list of prodigy docs by input hash
     """
-    return sorted(lst, key=lambda k: k['_input_hash'])
+    return sorted(lst, key=lambda k: k["_input_hash"])
+
 
 def combine_token_label_pairs(pairs):
     """Combines a list of [(token, label), (token, label)] to give
     (token,label,label).
     """
     return pairs[0][0:] + tuple(pair[1] for pair in pairs[1:])
 
+
 @plac.annotations(
-    input_files=("Comma separated list of paths to jsonl files containing prodigy docs.", "positional", None, str),
+    input_files=(
+        "Comma separated list of paths to jsonl files containing prodigy docs.",
+        "positional",
+        None,
+        str,
+    ),
     output_file=("Path to output tsv file.", "positional", None, str),
     respect_lines=(
         "Respect line endings? Or parse entire document in a single string?",
@@ -343,16 +359,22 @@ def prodigy_to_tsv(
     # NOTE: Use of reduce to handle pairs_list of unknown length
 
     if len(pairs_list) > 1:
-        merged_pairs = (combine_token_label_pairs(pairs) for pairs in reduce(zip, pairs_list))
-        example_pairs = [combine_token_label_pairs(pairs) for i, pairs in enumerate(reduce(zip, pairs_list)) if i < 15]
+        merged_pairs = (
+            combine_token_label_pairs(pairs) for pairs in reduce(zip, pairs_list)
+        )
+        example_pairs = [
+            combine_token_label_pairs(pairs)
+            for i, pairs in enumerate(reduce(zip, pairs_list))
+            if i < 15
+        ]
     else:
         merged_pairs = pairs_list[0]
         example_pairs = merged_pairs[0:14]
 
     with open(output_file, "w") as fb:
         writer = csv.writer(fb, delimiter="\t")
         # Write DOCSTART and a blank line
-        #writer.writerows([("DOCSTART", None), (None, None)])
+        # writer.writerows([("DOCSTART", None), (None, None)])
         writer.writerows(merged_pairs)
 
     # Print out the first ten rows as a sense check
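
To see the merging step end to end, a minimal sketch with two token-label lists (toy data; in prodigy_to_tsv the lists correspond to the input datasets):

from functools import reduce

# combine_token_label_pairs as defined above.
splitting = [("Smith", "b-r"), (",", "i-r"), ("2019", "e-r")]
parsing = [("Smith", "author"), (",", "o"), ("2019", "year")]
pairs_list = [splitting, parsing]

merged = [combine_token_label_pairs(p) for p in reduce(zip, pairs_list)]
# -> [("Smith", "b-r", "author"), (",", "i-r", "o"), ("2019", "e-r", "year")]
# Each tsv row then carries the token followed by one label column per dataset.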

tests/prodigy/test_prodigy_entrypoints.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,7 @@ def test_prodigy_to_tsv(tmpdir):
         respect_docs=True,
     )
 
+
 def test_prodigy_to_tsv_multiple_inputs(tmpdir):
     prodigy_to_tsv(
         TEST_TOKEN_LABELLED + "," + TEST_TOKEN_LABELLED,
@@ -43,6 +44,7 @@ def test_prodigy_to_tsv_multiple_inputs(tmpdir):
         respect_docs=True,
     )
 
+
 def test_reach_to_prodigy(tmpdir):
     reach_to_prodigy(TEST_REACH, os.path.join(tmpdir, "prodigy.jsonl"))
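
For context, a hedged sketch of how the multi-input entrypoint might be invoked outside the test suite (file names are placeholders; only the comma-separated input_files, the output_file, and respect_docs are confirmed by this diff):

from deep_reference_parser.prodigy.prodigy_to_tsv import prodigy_to_tsv

# Merge two Prodigy annotation files into a single token-label tsv.
prodigy_to_tsv(
    "splitting_annotations.jsonl,parsing_annotations.jsonl",  # input_files (placeholder paths)
    "token_labels.tsv",                                       # output_file (placeholder path)
    respect_docs=True,
)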

tests/prodigy/test_prodigy_to_tsv.py

Lines changed: 16 additions & 9 deletions
@@ -7,7 +7,11 @@
 import pytest
 
 from deep_reference_parser.io import read_jsonl
-from deep_reference_parser.prodigy.prodigy_to_tsv import TokenLabelPairs, prodigy_to_tsv, check_inputs
+from deep_reference_parser.prodigy.prodigy_to_tsv import (
+    TokenLabelPairs,
+    prodigy_to_tsv,
+    check_inputs,
+)
 
 from .common import TEST_SPANS, TEST_TOKENS
 
@@ -740,6 +744,7 @@ def test_reference_spans_real_example(doc):
 
     assert actual == expected
 
+
 def test_check_input_exist_on_doc_mismatch():
 
     dataset_a = [{"_input_hash": "a1"}, {"_input_hash": "a2"}]
@@ -748,31 +753,33 @@ def test_check_input_exist_on_doc_mismatch():
     with pytest.raises(SystemExit):
         check_inputs([dataset_a, dataset_b])
 
+
 def test_check_input_exist_on_tokens_mismatch():
 
     dataset_a = [
-        {"_input_hash": "a", "tokens": [{"text":"a"}]},
-        {"_input_hash": "a", "tokens": [{"text":"b"}]},
+        {"_input_hash": "a", "tokens": [{"text": "a"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
     ]
 
     dataset_b = [
-        {"_input_hash": "a", "tokens": [{"text":"b"}]},
-        {"_input_hash": "a", "tokens": [{"text":"b"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
     ]
 
     with pytest.raises(SystemExit):
         check_inputs([dataset_a, dataset_b])
 
+
 def test_check_input():
 
     dataset_a = [
-        {"_input_hash": "a", "tokens": [{"text":"a"}]},
-        {"_input_hash": "a", "tokens": [{"text":"b"}]},
+        {"_input_hash": "a", "tokens": [{"text": "a"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
     ]
 
     dataset_b = [
-        {"_input_hash": "a", "tokens": [{"text":"a"}]},
-        {"_input_hash": "a", "tokens": [{"text":"b"}]},
+        {"_input_hash": "a", "tokens": [{"text": "a"}]},
+        {"_input_hash": "a", "tokens": [{"text": "b"}]},
     ]
 
     assert check_inputs([dataset_a, dataset_b])
