Skip to content

Commit cdc4cad

Browse files
poorva87, tensorflower-gardener
authored and committed
Internal change
PiperOrigin-RevId: 338521847
1 parent 4f50e2f commit cdc4cad

File tree

4 files changed

+832
-0
lines changed

4 files changed

+832
-0
lines changed
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
16+
17+
import abc
18+
from typing import Any, Callable, Dict, Tuple
19+
20+
import tensorflow as tf
21+
22+
from tensorflow.python.framework import dtypes
23+
24+
# Return type of `DecodingModule.generate` / `_process_finished_state`:
# (finished sequences, finished scores).
Output = Tuple[tf.Tensor, tf.Tensor]

# Return type of `DecodingModule._grow_alive_seq`:
# (top sequences, scores of those sequences, new ids, new alive cache).
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]

# Return type of `DecodingModule._create_initial_state`:
# (initial state dictionary, matching shape invariants).
InitialState = Tuple[Dict[str, Any], Dict[str, Any]]
27+
28+
29+
class StateKeys:
  """String keys of the dictionary that holds the decoding-loop state."""

  # --- Loop bookkeeping ------------------------------------------------------
  # Index of the current loop iteration.
  CUR_INDEX = "CUR_INDEX"

  # --- Alive sequences (no EOS token generated yet) --------------------------
  # Top alive sequences per batch item. Once a sequence reaches EOS it is
  # marked finished and moved to FINISHED_SEQ.
  # Shape: [batch_size, beam_size, CUR_INDEX + 1] for SequenceBeamSearch,
  # [batch_size, CUR_INDEX + 1] otherwise.
  ALIVE_SEQ = "ALIVE_SEQ"
  # Log probability of every alive sequence. Shape: [batch_size, beam_size].
  ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
  # Per-sequence cache dictionary: encoder output, attention bias and the
  # decoder attention output of the previous iteration.
  ALIVE_CACHE = "ALIVE_CACHE"

  # --- Finished sequences ----------------------------------------------------
  # Top finished sequences per batch item, zero-padded when shorter than
  # CUR_INDEX + 1. Shape: [batch_size, beam_size, CUR_INDEX + 1].
  FINISHED_SEQ = "FINISHED_SEQ"
  # Score of every finished sequence (log probability / length norm).
  # Shape: [batch_size, beam_size].
  FINISHED_SCORES = "FINISHED_SCORES"
  # Marks which FINISHED_SEQ entries are real: True -> finished sequence,
  # False -> filler (initially every entry is filler).
  # Shape: [batch_size, beam_size].
  FINISHED_FLAGS = "FINISHED_FLAGS"
59+
60+
61+
class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
  """A base class for the API required for decoding (go/decoding-tf-nlp).

  Concrete strategies implement the abstract hooks below; `generate` drives
  them inside a `tf.while_loop`.

  Subclasses may set a boolean `padded_decode` attribute. When it is truthy,
  `generate` reads the batch size from the static shape of `initial_ids`;
  otherwise (including when the attribute is absent) it uses the dynamic
  shape.
  """

  def __init__(self,
               length_normalization_fn: Callable[[int, tf.DType], float],
               dtype: tf.DType = tf.float32):
    """Initialize the Decoding Module.

    Args:
      length_normalization_fn: Closure for returning length normalization
        parameter. Function accepts input as length, dtype and returns float.
      dtype: A tensorflow data type used for score computation. The default is
        tf.float32.
    """
    self.length_normalization_fn = length_normalization_fn
    self.dtype = tf.as_dtype(dtype)

  def generate(self,
               initial_ids: tf.Tensor,
               initial_cache: Dict[str, tf.Tensor]) -> Output:
    """Implements the decoding strategy (beam_search or sampling).

    Args:
      initial_ids: initial ids to pass into the symbols_to_logits_fn.
        int tensor with shape [batch_size, 1]
      initial_cache: dictionary for caching model outputs from previous step.

    Returns:
      Tuple of tensors representing
        finished_sequence: shape [batch, max_seq_length]
        finished_scores: [batch]
    """
    # `padded_decode` is not set by this base class' __init__; fall back to
    # the dynamic batch size when a subclass did not define it instead of
    # failing with AttributeError.
    padded_decode = getattr(self, "padded_decode", False)
    batch_size = (
        initial_ids.shape.as_list()[0]
        if padded_decode else tf.shape(initial_ids)[0])

    state, state_shapes = self._create_initial_state(initial_ids,
                                                     initial_cache,
                                                     batch_size)

    def _generate_step(state):
      """Runs one decoding iteration and returns the updated loop state."""
      topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq(
          state, batch_size)
      new_finished_flags = self._finished_flags(topk_ids, state)
      alive_state = self._get_new_alive_state(topk_seq,
                                              topk_log_probs,
                                              new_finished_flags,
                                              new_cache)
      finished_state = self._get_new_finished_state(state,
                                                    topk_seq,
                                                    topk_log_probs,
                                                    new_finished_flags,
                                                    batch_size)
      new_state = {
          StateKeys.CUR_INDEX: state[StateKeys.CUR_INDEX] + 1
      }
      new_state.update(alive_state)
      new_state.update(finished_state)
      # tf.while_loop expects the body to return the same structure as
      # `loop_vars`, which is a single-element list here.
      return [new_state]

    # Gradients are stopped on every tensor of the final loop state.
    finished_state = tf.nest.map_structure(
        tf.stop_gradient,
        tf.while_loop(
            self._continue_search,
            _generate_step,
            loop_vars=[state],
            shape_invariants=[state_shapes],
            parallel_iterations=1))
    return self._process_finished_state(finished_state[0])

  @abc.abstractmethod
  def _create_initial_state(self,
                            initial_ids: tf.Tensor,
                            initial_cache: Dict[str, tf.Tensor],
                            batch_size: int) -> InitialState:
    """Return initial state dictionary and its shape invariants."""

  @abc.abstractmethod
  def _grow_alive_seq(self,
                      state: Dict[str, Any],
                      batch_size: int) -> InternalState:
    """Grow alive sequences by one token.

    Args:
      state: A dictionary with the current loop state.
      batch_size: The given batch size

    Returns:
      Tuple of
      (Top sequences,
       Scores of returned sequences,
       New ids,
       New alive cache)
    """

  @abc.abstractmethod
  def _get_new_alive_state(
      self,
      new_seq: tf.Tensor,
      new_log_probs: tf.Tensor,
      new_finished_flags: tf.Tensor,
      new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
    """Gather the sequences that are still alive.

    Args:
      new_seq: New sequences generated by growing the current alive sequences;
        int32 tensor.
      new_log_probs: Log probabilities of the new sequences; float32 tensor.
      new_finished_flags: A boolean Tensor of finished flags, as returned by
        `_finished_flags`.
      new_cache: Dict of cached values for each sequence.

    Returns:
      Dictionary with alive keys from StateKeys.
    """

  @abc.abstractmethod
  def _get_new_finished_state(self,
                              state: Dict[str, Any],
                              new_seq: tf.Tensor,
                              new_log_probs: tf.Tensor,
                              new_finished_flags: tf.Tensor,
                              batch_size: int) -> Dict[str, tf.Tensor]:
    """Combine new and old finished sequences.

    Args:
      state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive sequences;
        int32 tensor.
      new_log_probs: Log probabilities of the new sequences; float32 tensor.
      new_finished_flags: A boolean Tensor of finished flags, as returned by
        `_finished_flags`.
      batch_size: The given batch size.

    Returns:
      Dictionary with finished keys from StateKeys.
    """

  @abc.abstractmethod
  def _process_finished_state(self, finished_state: Dict[str, Any]) -> Output:
    """Process the alive/finished state to return final sequences and scores."""

  @abc.abstractmethod
  def _continue_search(self, state: Dict[str, Any]) -> tf.Tensor:
    """Returns a bool tensor if the decoding loop should continue."""

  @abc.abstractmethod
  def _finished_flags(self,
                      topk_ids: tf.Tensor,
                      state: Dict[str, Any]) -> tf.Tensor:
    """Calculate the finished flags."""

  def inf(self):
    """Returns a value close to infinity, but is still finite in `dtype`.

    This is useful to get a very large value that is still zero when multiplied
    by zero. The floating-point "Inf" value is NaN when multiplied by zero.

    Returns:
      A very large value.

    Raises:
      AssertionError: if `self.dtype` is not float32, bfloat16 or float16.
    """
    if self.dtype == dtypes.float32 or self.dtype == dtypes.bfloat16:
      return 1e7
    if self.dtype == dtypes.float16:
      return dtypes.float16.max
    raise AssertionError("Invalid dtype: %s" % self.dtype)

  @staticmethod
  def _log_prob_from_logits(logits):
    """Normalizes logits into log probabilities along the last axis."""
    return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)

  @staticmethod
  def _shape_list(tensor):
    """Return a list of the tensor's shape, and ensure no None values in list."""
    # Statically known shape (may contain None's for unknown dimensions).
    static_shape = tensor.get_shape().as_list()

    # Substitute the dynamic dimension wherever the static one is unknown.
    dynamic_shape = tf.shape(tensor)
    return [
        dynamic_shape[i] if dim is None else dim
        for i, dim in enumerate(static_shape)
    ]

  @staticmethod
  def _get_shape_keep_last_dim(tensor):
    """Returns a TensorShape keeping only the (static) last dimension.

    Every dimension but the last is replaced by None; the last one is kept
    unless it is only known dynamically (i.e. it is a tf.Tensor).

    Args:
      tensor: tensor whose shape is inspected.

    Returns:
      A tf.TensorShape with the same rank as `tensor`.
    """
    shape_list_obj = DecodingModule._shape_list(tensor)
    if not shape_list_obj:
      # Rank-0 tensor: there is no last dimension to keep.
      return tf.TensorShape([])

    last_dim = shape_list_obj[-1]
    if isinstance(last_dim, tf.Tensor):
      last_dim = None
    return tf.TensorShape([None] * (len(shape_list_obj) - 1) + [last_dim])

  @staticmethod
  def _expand_to_same_rank(tensor, target):
    """Expands a given tensor to target's rank to be broadcastable.

    Args:
      tensor: input tensor to tile. Shape: [b, d1, ..., da]
      target: target tensor. Shape: [b, d1, ..., da, ..., dn]

    Returns:
      Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target

    Raises:
      ValueError, if the shape rank of tensor or target is None.
    """
    if tensor.shape.rank is None:
      raise ValueError("Expect rank for tensor shape, but got None.")
    if target.shape.rank is None:
      raise ValueError("Expect rank for target shape, but got None.")

    with tf.name_scope("expand_rank"):
      diff_rank = target.shape.rank - tensor.shape.rank
      # Append one trailing unit dimension per missing rank; a non-positive
      # difference leaves the tensor unchanged.
      for _ in range(diff_rank):
        tensor = tf.expand_dims(tensor, -1)
      return tensor
287+
288+
289+
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Test decoding utility methods."""
16+
17+
import abc
18+
import tensorflow as tf
19+
20+
from official.nlp.modeling.ops import decoding_module
21+
22+
23+
def length_normalization(length, dtype):
  """Return length normalization factor."""
  # The exponent is 0.0, so the factor does not depend on `length`.
  exponent = 0.0
  base = (5. + tf.cast(length, dtype)) / 6.
  return tf.pow(base, exponent)
26+
27+
28+
class TestSubclass(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
  """Concrete DecodingModule with no-op hooks, used to test base helpers."""

  def __init__(self,
               length_normalization_fn=length_normalization,
               dtype=tf.float32):
    """Initializes the base module.

    Args:
      length_normalization_fn: Callable taking (length, dtype); defaults to
        the module-level `length_normalization`.
      dtype: data type forwarded to the base class.
    """
    # Forward the caller-supplied function rather than always passing the
    # module-level default.
    super().__init__(
        length_normalization_fn=length_normalization_fn, dtype=dtype)

  # The abstract hooks are stubbed out: they only need to exist so that the
  # class can be instantiated.

  def _create_initial_state(self, initial_ids, initial_cache, batch_size):
    pass

  def _grow_alive_seq(self, state, batch_size):
    pass

  def _process_finished_state(self, finished_state):
    pass

  def _get_new_finished_state(self, state, new_seq, new_log_probs,
                              new_finished_flags, batch_size):
    pass

  def _finished_flags(self, topk_ids, state):
    pass

  def _continue_search(self, state):
    pass

  def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
                           new_cache):
    pass
58+
59+
60+
class DecodingModuleTest(tf.test.TestCase):
  """Tests for the helper methods of DecodingModule."""

  def test_get_shape_keep_last_dim(self):
    # The second dimension is passed as a tensor instead of a Python int.
    four = tf.constant(4.0)
    tensor_dim = tf.cast(tf.sqrt(four), tf.int32)
    inputs = tf.ones([7, tensor_dim, 2, 5])
    kept_shape = decoding_module.DecodingModule._get_shape_keep_last_dim(inputs)
    self.assertAllEqual([None, None, None, 5], kept_shape.as_list())

  def test_shape_list(self):
    inputs = tf.ones([7, 1])
    dims = decoding_module.DecodingModule._shape_list(inputs)
    self.assertAllEqual([7, 1], dims)

  def test_inf(self):
    module = TestSubclass()
    expected = tf.constant(10000000., tf.float32)
    self.assertAllEqual(module.inf(), expected)

  def test_length_normalization(self):
    module = TestSubclass()
    expected = tf.constant(1.0, tf.float32)
    actual = module.length_normalization_fn(32, tf.float32)
    self.assertAllEqual(actual, expected)
82+
83+
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()

0 commit comments

Comments
 (0)