 
 
 def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags,
-                                beam_dim, prefix="default",
-                                states=None):
+                                beam_dim, prefix="default"):
   """Given sequences and scores, will gather the top k=beam size sequences.
 
   This function is used to grow alive, and finished. It takes sequences,
   scores, and flags, and returns the top k from sequences, scores_to_gather,
   and flags based on the values in scores.
 
-  This method permits easy introspection using tfdbg. It adds three named ops
+  This method permits easy introspection using tfdbg. It adds two named ops
   that are prefixed by `prefix`:
     - _topk_seq: the tensor for topk_seq returned by this method.
     - _topk_flags: the tensor for topk_finished_flags returned by this method.
-    - _topk_scores: the tensor for tokp_gathered_scores returned by this method.
 
   Args:
     sequences: Tensor of sequences that we need to gather from.
@@ -57,17 +55,18 @@ def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags,
       EOS or not
     beam_dim: mtf.Dimension
     prefix: an optional string
-    states: an optional list of mtf.Tensor
   Returns:
     Tuple of
     (topk_seq [batch_size, beam_size, decode_length],
      topk_gathered_scores [batch_size, beam_size],
      topk_finished_flags[batch_size, beam_size],
-     topk_gathered_states)
+     selector)
   """
   unused_batch_dim, old_beam_dim, unused_length_dim = sequences.shape.dims
   topk_indices, _ = mtf.top_k(scores, old_beam_dim, beam_dim)
 
+  selector = mtf.one_hot(topk_indices, old_beam_dim, dtype=tf.float32)
+
   # Gather up the highest scoring sequences.
   # For each operation added, give it
   # a concrete name to simplify observing these operations with tfdbg.
@@ -81,11 +80,7 @@ def gather(tensor, name):
   topk_seq = gather(sequences, "_seq")
   topk_flags = gather(flags, "_flags")
   topk_gathered_scores = gather(scores_to_gather, "_scores")
-  if states is None:
-    topk_gathered_states = None
-  else:
-    topk_gathered_states = [gather(state, "_topk_states") for state in states]
-  return topk_seq, topk_gathered_scores, topk_flags, topk_gathered_states
+  return topk_seq, topk_gathered_scores, topk_flags, selector
 
 
 def beam_search(logits_fn,
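The selector returned above is a gather expressed as a one-hot product: each row of mtf.one_hot(topk_indices, old_beam_dim) picks out one slot along the old beam axis, so contracting it against a tensor reproduces mtf.gather over that dimension. A minimal NumPy analogue of the equivalence (illustrative only; the shapes and names below are hypothetical, not the mtf API):

import numpy as np

batch, old_beam, new_beam = 2, 6, 3   # e.g. 2*beam candidates -> beam survivors
scores = np.random.rand(batch, old_beam)

# Indices of the top new_beam candidates per batch (what mtf.top_k yields).
topk_indices = np.argsort(-scores, axis=1)[:, :new_beam]     # [batch, new_beam]
selector = np.eye(old_beam)[topk_indices]                    # [batch, new_beam, old_beam]

# Contracting the one-hot selector == indexing directly.
via_selector = np.einsum("bno,bo->bn", selector, scores)
via_gather = np.take_along_axis(scores, topk_indices, axis=1)
assert np.allclose(via_selector, via_gather)

Returning this matrix instead of gathered states lets the caller fuse several such gathers into a single einsum, which is what inner_loop does further down.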
@@ -213,9 +208,9 @@ def _my_concat(a, b):
     curr_finished_flags = _my_concat(finished_flags, curr_finished)
     return compute_topk_scores_and_seq(
         curr_finished_seq, curr_finished_scores, curr_finished_scores,
-        curr_finished_flags, beam_dim, "grow_finished", states=None)
+        curr_finished_flags, beam_dim, "grow_finished")
 
-  def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
+  def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished):
     """Given sequences and scores, will gather the top k=beam size sequences.
 
     Args:
@@ -226,7 +221,6 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
         [batch, beam]
       curr_finished: Finished flags for each of these sequences.
         [batch, beam]
-      states: list of mtf.Tensor
     Returns:
       Tuple of
         (Topk sequences based on scores,
@@ -238,7 +232,7 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
     curr_scores += mtf.cast(curr_finished, curr_scores.dtype) * -INF
     return compute_topk_scores_and_seq(curr_seq, curr_scores, curr_log_probs,
                                        curr_finished, beam_dim,
-                                       "grow_alive", states)
+                                       "grow_alive")
 
   def grow_topk(i, alive_seq, alive_log_probs, states=None):
     r"""Inner beam search loop.
@@ -298,6 +292,8 @@ def grow_topk(i, alive_seq, alive_log_probs, states=None):
     top_beam_index = top_ids // vocab_dim.size
     top_ids %= vocab_dim.size  # Unflatten the ids
 
+    selector = mtf.one_hot(top_beam_index, beam_dim, dtype=tf.float32)
+
     def my_gather(tensor):
       return mtf.gather(
           tensor, top_beam_index, beam_dim,
@@ -308,14 +304,12 @@ def my_gather(tensor):
     # bools
     top_seq = my_gather(alive_seq)
 
-    if states:
-      states = [my_gather(state) for state in new_states]
-
     # Append the most probable alive
     top_seq += top_ids * mtf.one_hot(i, length_dim, dtype=tf.int32)
     top_finished = mtf.equal(top_ids, eos_id)
 
-    return top_seq, top_log_probs, top_scores, top_finished, states
+    return (
+        top_seq, top_log_probs, top_scores, top_finished, new_states, selector)
 
   def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
                  finished_flags, *states):
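grow_topk likewise stops gathering the states itself: it returns the ungathered new_states together with a selector built from top_beam_index, which carries the same information as the my_gather call it replaces. A small NumPy sketch of that claim (hypothetical shapes; parent_beam and depth are made-up names for illustration):

import numpy as np

batch, beam, candidates, depth = 2, 3, 6, 4
states = np.random.rand(batch, beam, depth)       # one state vector per old beam
parent_beam = np.random.randint(0, beam, size=(batch, candidates))  # like top_beam_index

selector = np.eye(beam)[parent_beam]              # [batch, candidates, beam]

# Contracting the selector == gathering each candidate's parent-beam state.
via_selector = np.einsum("bck,bkd->bcd", selector, states)
via_gather = states[np.arange(batch)[:, None], parent_beam]
assert np.allclose(via_selector, via_gather)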
@@ -368,14 +362,26 @@ def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
     # 2. Extract the ones that have finished and haven't finished
     # 3. Recompute the contents of finished based on scores.
     (top2k_seq, top2k_log_probs, top2k_scores, top2k_finished,
-     top2k_states) = grow_topk(i, alive_seq, alive_log_probs, states)
-    alive_seq, alive_log_probs, _, states = grow_alive(
-        top2k_seq, top2k_scores, top2k_log_probs, top2k_finished, top2k_states)
+     new_states, first_selector) = grow_topk(
+         i, alive_seq, alive_log_probs, states)
+    alive_seq, alive_log_probs, _, second_selector = grow_alive(
+        top2k_seq, top2k_scores, top2k_log_probs, top2k_finished)
     finished_seq, finished_scores, finished_flags, _ = grow_finished(
         finished_seq, finished_scores, finished_flags, top2k_seq, top2k_scores,
         top2k_finished)
+    old_beam_dim = mtf.Dimension("old_beam", beam_dim.size)
+    selector = mtf.einsum(
+        [mtf.rename_dimension(first_selector, beam_dim.name, old_beam_dim.name),
+         second_selector],
+        output_shape=[batch_dim, old_beam_dim, beam_dim])
+    new_states = [
+        mtf.einsum(
+            [mtf.rename_dimension(state, beam_dim.name, old_beam_dim.name),
+             mtf.cast(selector, state.dtype)],
+            reduced_dims=[old_beam_dim], output_shape=state.shape)
+        for state in new_states]
     return (i + 1, alive_seq, alive_log_probs, finished_seq, finished_scores,
-            finished_flags) + tuple(states)
+            finished_flags) + tuple(new_states)
 
   def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                    finished_scores, finished_in_finished, *unused_states):
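The einsum block above works because two one-hot selection matrices contract, over the shared candidate dimension, into a single [batch, old_beam, new_beam] selection matrix; every decoder state can then be re-indexed in one pass instead of being gathered once in grow_topk and again in grow_alive. A NumPy check of that equivalence (illustrative shapes and names only, not the mtf call signatures):

import numpy as np

batch, beam, candidates, depth = 2, 3, 6, 4
states = np.random.rand(batch, beam, depth)                 # indexed by old beam

parent_beam = np.random.randint(0, beam, size=(batch, candidates))       # grow_topk's pick
chosen_candidate = np.random.randint(0, candidates, size=(batch, beam))  # grow_alive's pick

first_selector = np.eye(beam)[parent_beam]                  # [batch, candidates, old_beam]
second_selector = np.eye(candidates)[chosen_candidate]      # [batch, new_beam, candidates]

# combined[b, i, j] == 1 iff new beam j traces back to old beam i.
combined = np.einsum("bci,bjc->bij", first_selector, second_selector)

# One contraction with the combined selector == two successive gathers.
one_shot = np.einsum("bij,bid->bjd", combined, states)
candidate_states = states[np.arange(batch)[:, None], parent_beam]
two_gathers = candidate_states[np.arange(batch)[:, None], chosen_candidate]
assert np.allclose(one_shot, two_gathers)

In the mtf code the contraction axes are selected by dimension name, which is why first_selector's beam dimension is renamed to "old_beam" before the einsum, and why the per-state einsum reduces old_beam while keeping each state's original shape.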