Skip to content

Commit 913640d

Browse files
saberkun authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 285765110
1 parent 722d9e5 commit 913640d

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

official/transformer/model/beam_search.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -323,13 +323,16 @@ def _search_step(self, state):
323323
new state dictionary.
324324
"""
325325
# Grow alive sequences by one token.
326-
new_seq, new_log_probs, new_cache = self._grow_alive_seq(state)
326+
new_seq, new_log_probs, topk_ids, new_cache = self._grow_alive_seq(state)
327+
new_finished_flags = tf.equal(topk_ids, self.eos_id)
327328
# Collect top beam_size alive sequences
328-
alive_state = self._get_new_alive_state(new_seq, new_log_probs, new_cache)
329+
alive_state = self._get_new_alive_state(new_seq, new_log_probs,
330+
new_finished_flags, new_cache)
329331

330332
# Combine newly finished sequences with existing finished sequences, and
331333
# collect the top k scoring sequences.
332-
finished_state = self._get_new_finished_state(state, new_seq, new_log_probs)
334+
finished_state = self._get_new_finished_state(state, new_seq, new_log_probs,
335+
new_finished_flags)
333336

334337
# Increment loop index and create new state dictionary
335338
new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
@@ -407,18 +410,20 @@ def _grow_alive_seq(self, state):
407410
tf.expand_dims(topk_ids, axis=0))
408411
topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
409412
else:
410-
topk_ids = tf.expand_dims(topk_ids, axis=2)
411-
topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
412-
return topk_seq, topk_log_probs, new_cache
413+
topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
414+
return topk_seq, topk_log_probs, topk_ids, new_cache
413415

414-
def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
416+
def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
417+
new_cache):
415418
"""Gather the top k sequences that are still alive.
416419
417420
Args:
418421
new_seq: New sequences generated by growing the current alive sequences
419422
int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
420-
new_log_probs: Log probabilities of new sequences
421-
float32 tensor with shape [batch_size, beam_size]
423+
new_log_probs: Log probabilities of new sequences float32 tensor with
424+
shape [batch_size, beam_size]
425+
new_finished_flags: A boolean Tensor indicating which of the new sequences
426+
have finished (i.e. their newest token is EOS) inside the beam.
422427
new_cache: Dict of cached values for each sequence.
423428
424429
Returns:
@@ -428,7 +433,6 @@ def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
428433
Dict cache storing decoder states for top alive sequences}
429434
"""
430435
# To prevent finished sequences from being considered, set log probs to -inf
431-
new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
432436
new_log_probs += tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype)
433437

434438
top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
@@ -441,15 +445,18 @@ def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
441445
_StateKeys.ALIVE_CACHE: top_alive_cache
442446
}
443447

444-
def _get_new_finished_state(self, state, new_seq, new_log_probs):
448+
def _get_new_finished_state(self, state, new_seq, new_log_probs,
449+
new_finished_flags):
445450
"""Combine new and old finished sequences, and gather the top k sequences.
446451
447452
Args:
448453
state: A dictionary with the current loop state.
449454
new_seq: New sequences generated by growing the current alive sequences
450455
int32 tensor with shape [batch_size, beam_size, i + 1]
451-
new_log_probs: Log probabilities of new sequences
452-
float32 tensor with shape [batch_size, beam_size]
456+
new_log_probs: Log probabilities of new sequences float32 tensor with
457+
shape [batch_size, beam_size]
458+
new_finished_flags: A boolean Tensor indicating which of the new sequences
459+
have finished (i.e. their newest token is EOS) inside the beam.
453460
454461
Returns:
455462
Dictionary with finished keys from _StateKeys:
@@ -476,7 +483,6 @@ def _get_new_finished_state(self, state, new_seq, new_log_probs):
476483
new_scores = new_log_probs / length_norm
477484

478485
# Set the scores of the still-alive seq in new_seq to large negative values.
479-
new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
480486
new_scores += ((1. - tf.cast(new_finished_flags, self.dtype)) *
481487
-inf(self.dtype))
482488

0 commit comments

Comments
 (0)