@@ -323,13 +323,16 @@ def _search_step(self, state):
       new state dictionary.
     """
     # Grow alive sequences by one token.
-    new_seq, new_log_probs, new_cache = self._grow_alive_seq(state)
+    new_seq, new_log_probs, topk_ids, new_cache = self._grow_alive_seq(state)
+    new_finished_flags = tf.equal(topk_ids, self.eos_id)
     # Collect top beam_size alive sequences
-    alive_state = self._get_new_alive_state(new_seq, new_log_probs, new_cache)
+    alive_state = self._get_new_alive_state(new_seq, new_log_probs,
+                                            new_finished_flags, new_cache)

     # Combine newly finished sequences with existing finished sequences, and
     # collect the top k scoring sequences.
-    finished_state = self._get_new_finished_state(state, new_seq, new_log_probs)
+    finished_state = self._get_new_finished_state(state, new_seq, new_log_probs,
+                                                  new_finished_flags)

     # Increment loop index and create new state dictionary
     new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
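The key change in `_search_step` is that the EOS comparison now happens once, against the freshly selected `topk_ids`, and the resulting flags are handed to both helpers. A minimal sketch of that comparison, assuming toy shapes and an illustrative `eos_id` of 1 (the real values come from the decoder's vocabulary and beam configuration):

```python
import tensorflow as tf

# Hypothetical values for illustration only: batch_size=1, 2 * beam_size=4.
eos_id = 1
topk_ids = tf.constant([[5, 1, 7, 1]])  # token ids just appended to each beam

# Mirrors the added line in _search_step: True wherever a beam emitted EOS.
new_finished_flags = tf.equal(topk_ids, eos_id)
print(new_finished_flags.numpy())  # [[False  True False  True]]
```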
@@ -407,18 +410,20 @@ def _grow_alive_seq(self, state):
                                              tf.expand_dims(topk_ids, axis=0))
       topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
     else:
-      topk_ids = tf.expand_dims(topk_ids, axis=2)
-      topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
-    return topk_seq, topk_log_probs, new_cache
+      topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
+    return topk_seq, topk_log_probs, topk_ids, new_cache

-  def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
+  def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
+                           new_cache):
     """Gather the top k sequences that are still alive.

     Args:
       new_seq: New sequences generated by growing the current alive sequences
         int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
-      new_log_probs: Log probabilities of new sequences
-        float32 tensor with shape [batch_size, beam_size]
+      new_log_probs: Log probabilities of new sequences, a float32 tensor with
+        shape [batch_size, beam_size]
+      new_finished_flags: A boolean tensor indicating which sequences in the
+        beam have just finished (reached EOS).
       new_cache: Dict of cached values for each sequence.

     Returns:
@@ -428,7 +433,6 @@ def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
         Dict cache storing decoder states for top alive sequences}
     """
     # To prevent finished sequences from being considered, set log probs to -inf
-    new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
     new_log_probs += tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype)

     top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
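With the flags passed in, `_get_new_alive_state` no longer re-derives them from `new_seq[:, :, -1]`; it only applies the mask. A rough sketch of that masking under toy values, with a stand-in for the module's `inf(dtype)` helper (its real definition is not shown in this diff):

```python
import tensorflow as tf

def inf(dtype):
  # Stand-in for the module's inf(dtype) helper; assumed to return a large value.
  return tf.constant(1e9, dtype=dtype)

new_log_probs = tf.constant([[-1.2, -0.3, -2.5, -0.9]])   # [batch, 2 * beam_size]
new_finished_flags = tf.constant([[False, True, False, True]])

# Same pattern as the diff: finished beams get -inf so the top-k gather
# that follows keeps only still-alive beams.
new_log_probs += tf.cast(new_finished_flags, tf.float32) * -inf(tf.float32)
print(new_log_probs.numpy())
```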
@@ -441,15 +445,18 @@ def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
         _StateKeys.ALIVE_CACHE: top_alive_cache
     }

-  def _get_new_finished_state(self, state, new_seq, new_log_probs):
+  def _get_new_finished_state(self, state, new_seq, new_log_probs,
+                              new_finished_flags):
     """Combine new and old finished sequences, and gather the top k sequences.

     Args:
       state: A dictionary with the current loop state.
       new_seq: New sequences generated by growing the current alive sequences
         int32 tensor with shape [batch_size, beam_size, i + 1]
-      new_log_probs: Log probabilities of new sequences
-        float32 tensor with shape [batch_size, beam_size]
+      new_log_probs: Log probabilities of new sequences, a float32 tensor with
+        shape [batch_size, beam_size]
+      new_finished_flags: A boolean tensor indicating which sequences in the
+        beam have just finished (reached EOS).

     Returns:
       Dictionary with finished keys from _StateKeys:
@@ -476,7 +483,6 @@ def _get_new_finished_state(self, state, new_seq, new_log_probs):
     new_scores = new_log_probs / length_norm

     # Set the scores of the still-alive seq in new_seq to large negative values.
-    new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
     new_scores += ((1. - tf.cast(new_finished_flags, self.dtype)) *
                    -inf(self.dtype))
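`_get_new_finished_state` applies the complementary mask: beams that did not just emit EOS are kept out of the finished top-k. A sketch under the same assumptions (toy numbers, illustrative `inf` helper):

```python
import tensorflow as tf

def inf(dtype):
  # Stand-in for the module's inf(dtype) helper; assumed to return a large value.
  return tf.constant(1e9, dtype=dtype)

new_scores = tf.constant([[-0.4, -0.8, -1.1]])   # length-normalized scores
new_finished_flags = tf.constant([[True, False, True]])

# Same pattern as the diff: still-alive beams (flag False) receive a large
# negative score so they never enter the finished top-k.
new_scores += (1. - tf.cast(new_finished_flags, tf.float32)) * -inf(tf.float32)
print(new_scores.numpy())
```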