1010from .deep_reference_parser import logger
1111
1212
def get_reference_spans(tokens, spans):
    """
    Given a corresponding list of tokens and a list of span labels, locate the
    start and end of each reference in the flattened token list.

    Args:
        tokens(list): A list of lists of tokens.
        spans(list): A corresponding list of lists of span labels, where a
            "b-r" label marks the beginning of a reference and "e-r" marks
            the end.

    Returns:
        tuple: (token_starts, token_ends, flat_tokens). token_starts and
            token_ends are parallel lists of indices into flat_tokens, one
            pair per reference.
    """

    # Flatten the lists of tokens and predictions into a single list.

    flat_tokens = list(itertools.chain.from_iterable(tokens))
    flat_predictions = list(itertools.chain.from_iterable(spans))

    # Find all b-r and e-r tokens.

    ref_starts = [
        index for index, label in enumerate(flat_predictions) if label == "b-r"
    ]

    ref_ends = [
        index for index, label in enumerate(flat_predictions) if label == "e-r"
    ]

    logger.debug("Found %s b-r tokens", len(ref_starts))
    logger.debug("Found %s e-r tokens", len(ref_ends))

    n_refs = len(ref_starts)

    # Split on each b-r: a reference ends one token before the next b-r.
    # The final reference's end is recorded as len(flat_tokens); callers
    # slice with flat_tokens[start:end + 1], so this over-by-one "inclusive"
    # index is harmless under Python slicing, but note it is NOT a valid
    # index into flat_tokens itself.

    token_starts = []
    token_ends = []
    for i in range(n_refs):
        token_starts.append(ref_starts[i])
        if i + 1 < n_refs:
            token_ends.append(ref_starts[i + 1] - 1)
        else:
            token_ends.append(len(flat_tokens))

    return token_starts, token_ends, flat_tokens
46+
def tokens_to_references(tokens, labels):
    """
    Given a corresponding list of tokens and a list of labels: split the tokens
    and return a list of references.

    Args:
        tokens(list): A list of tokens.
        labels(list): A corresponding list of labels.

    """

    starts, ends, flat_tokens = get_reference_spans(tokens, labels)

    # Each (start, end) pair is an inclusive span; join the tokens of each
    # span into a single reference string.
    return [
        " ".join(flat_tokens[start : end + 1])
        for start, end in zip(starts, ends)
    ]
6267
63-
6468def tokens_to_reference_lists (tokens , spans , components ):
6569 """
6670 Given a corresponding list of tokens, a list of
@@ -75,37 +79,12 @@ def tokens_to_reference_lists(tokens, spans, components):
7579
7680 """
7781
78- # Flatten the lists of tokens and predictions into a single list.
79-
80- flat_tokens = list (itertools .chain .from_iterable (tokens ))
81- flat_spans = list (itertools .chain .from_iterable (spans ))
82+ token_starts , token_ends , flat_tokens = get_reference_spans (tokens , spans )
8283 flat_components = list (itertools .chain .from_iterable (components ))
8384
84- # Find all b-r and e-r tokens.
85-
86- ref_starts = [
87- index for index , label in enumerate (flat_spans ) if label == "b-r"
88- ]
89-
90- ref_ends = [index for index , label in enumerate (flat_spans ) if label == "e-r" ]
91-
92- logger .debug ("Found %s b-r tokens" , len (ref_starts ))
93- logger .debug ("Found %s e-r tokens" , len (ref_ends ))
94-
9585 references_components = []
86+ for token_start , token_end in zip (token_starts , token_ends ):
9687
97- n_refs = len (ref_starts )
98-
99- # Split on each b-r.
100-
101- for i in range (0 , n_refs ):
102- token_start = ref_starts [i ]
103- if i + 1 < n_refs :
104-
105- token_end = ref_starts [i + 1 ] - 1
106- else :
107- token_end = len (flat_tokens )
108-
10988 ref_tokens = flat_tokens [token_start : token_end + 1 ]
11089 ref_components = flat_components [token_start : token_end + 1 ]
11190 flat_ref = " " .join (ref_tokens )
0 commit comments