Commit e4de5a4

refactoring tokens_to_references
1 parent 2188edb commit e4de5a4

File tree

1 file changed: deep_reference_parser/tokens_to_references.py (27 additions, 48 deletions)
@@ -10,21 +10,12 @@
 from .deep_reference_parser import logger


-def tokens_to_references(tokens, labels):
-    """
-    Given a corresponding list of tokens and a list of labels: split the tokens
-    and return a list of references.
-
-    Args:
-        tokens(list): A list of tokens.
-        labels(list): A corresponding list of labels.
-
-    """
+def get_reference_spans(tokens, spans):

     # Flatten the lists of tokens and predictions into a single list.

     flat_tokens = list(itertools.chain.from_iterable(tokens))
-    flat_predictions = list(itertools.chain.from_iterable(labels))
+    flat_predictions = list(itertools.chain.from_iterable(spans))

     # Find all b-r and e-r tokens.

@@ -37,30 +28,43 @@ def tokens_to_references(tokens, labels):
     logger.debug("Found %s b-r tokens", len(ref_starts))
     logger.debug("Found %s e-r tokens", len(ref_ends))

-    references = []
-
     n_refs = len(ref_starts)

     # Split on each b-r.
-    # TODO: It may be worth including some simple post processing steps here
-    # to pick up false positives, for instance cutting short a reference
-    # after n tokens.

+    token_starts = []
+    token_ends = []
     for i in range(0, n_refs):
-        token_start = ref_starts[i]
+        token_starts.append(ref_starts[i])
         if i + 1 < n_refs:
-
-            token_end = ref_starts[i + 1] - 1
+            token_ends.append(ref_starts[i + 1] - 1)
         else:
-            token_end = len(flat_tokens)
+            token_ends.append(len(flat_tokens))

+    return token_starts, token_ends, flat_tokens
+
+
+def tokens_to_references(tokens, labels):
+    """
+    Given a corresponding list of tokens and a list of labels: split the tokens
+    and return a list of references.
+
+    Args:
+        tokens(list): A list of tokens.
+        labels(list): A corresponding list of labels.
+
+    """
+
+    token_starts, token_ends, flat_tokens = get_reference_spans(tokens, labels)
+
+    references = []
+    for token_start, token_end in zip(token_starts, token_ends):
         ref = flat_tokens[token_start : token_end + 1]
         flat_ref = " ".join(ref)
         references.append(flat_ref)

     return references

-
 def tokens_to_reference_lists(tokens, spans, components):
     """
     Given a corresponding list of tokens, a list of

@@ -75,37 +79,12 @@ def tokens_to_reference_lists(tokens, spans, components):

     """

-    # Flatten the lists of tokens and predictions into a single list.
-
-    flat_tokens = list(itertools.chain.from_iterable(tokens))
-    flat_spans = list(itertools.chain.from_iterable(spans))
+    token_starts, token_ends, flat_tokens = get_reference_spans(tokens, spans)
     flat_components = list(itertools.chain.from_iterable(components))

-    # Find all b-r and e-r tokens.
-
-    ref_starts = [
-        index for index, label in enumerate(flat_spans) if label == "b-r"
-    ]
-
-    ref_ends = [index for index, label in enumerate(flat_spans) if label == "e-r"]
-
-    logger.debug("Found %s b-r tokens", len(ref_starts))
-    logger.debug("Found %s e-r tokens", len(ref_ends))
-
     references_components = []
+    for token_start, token_end in zip(token_starts, token_ends):

-    n_refs = len(ref_starts)
-
-    # Split on each b-r.
-
-    for i in range(0, n_refs):
-        token_start = ref_starts[i]
-        if i + 1 < n_refs:
-
-            token_end = ref_starts[i + 1] - 1
-        else:
-            token_end = len(flat_tokens)
-
         ref_tokens = flat_tokens[token_start : token_end + 1]
         ref_components = flat_components[token_start : token_end + 1]
         flat_ref = " ".join(ref_tokens)
