
Commit 8b721b0

nshazeer authored and Copybara-Service committed

MTF beam search - gather slices in one operation instead of two. Beam search is still slow on TPU, possibly due to bad XLA layout choices.

PiperOrigin-RevId: 224538842

1 parent 4a3e81f, commit 8b721b0
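
The "one operation instead of two" in the commit message refers to gathering the chosen beams through a one-hot "selector" tensor and a single einsum, rather than a separate gather per tensor. A minimal numpy sketch of that idea (illustrative shapes and variable names only, not the Mesh TensorFlow code in the diff below):

    import numpy as np

    batch, old_beam, new_beam, hidden = 2, 4, 4, 8
    state = np.random.randn(batch, old_beam, hidden)
    topk_indices = np.random.randint(0, old_beam, size=(batch, new_beam))

    # selector[b, n, o] == 1 iff new beam n was copied from old beam o.
    selector = np.eye(old_beam)[topk_indices]   # [batch, new_beam, old_beam]

    # Gathering along the beam dimension becomes one einsum contraction ...
    gathered = np.einsum("bno,boh->bnh", selector, state)

    # ... which matches an explicit per-batch index gather.
    expected = np.stack([state[b, topk_indices[b]] for b in range(batch)])
    assert np.allclose(gathered, expected)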

File tree

1 file changed: +29 -23 lines changed


mesh_tensorflow/beam_search.py

Lines changed: 29 additions & 23 deletions
@@ -29,19 +29,17 @@
 
 
 def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags,
-                                beam_dim, prefix="default",
-                                states=None):
+                                beam_dim, prefix="default"):
   """Given sequences and scores, will gather the top k=beam size sequences.
 
   This function is used to grow alive, and finished. It takes sequences,
   scores, and flags, and returns the top k from sequences, scores_to_gather,
   and flags based on the values in scores.
 
-  This method permits easy introspection using tfdbg. It adds three named ops
+  This method permits easy introspection using tfdbg. It adds two named ops
   that are prefixed by `prefix`:
     - _topk_seq: the tensor for topk_seq returned by this method.
     - _topk_flags: the tensor for topk_finished_flags returned by this method.
-    - _topk_scores: the tensor for tokp_gathered_scores returned by this method.
 
   Args:
     sequences: Tensor of sequences that we need to gather from.
@@ -57,17 +55,18 @@ def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags,
       EOS or not
     beam_dim: mtf.Dimension
     prefix: an optional string
-    states: an optional list of mtf.Tensor
   Returns:
     Tuple of
     (topk_seq [batch_size, beam_size, decode_length],
      topk_gathered_scores [batch_size, beam_size],
      topk_finished_flags[batch_size, beam_size],
-     topk_gathered_states)
+     selector)
   """
   unused_batch_dim, old_beam_dim, unused_length_dim = sequences.shape.dims
   topk_indices, _ = mtf.top_k(scores, old_beam_dim, beam_dim)
 
+  selector = mtf.one_hot(topk_indices, old_beam_dim, dtype=tf.float32)
+
   # Gather up the highest scoring sequences.
   # For each operation added, give it
   # a concrete name to simplify observing these operations with tfdbg.
@@ -81,11 +80,7 @@ def gather(tensor, name):
   topk_seq = gather(sequences, "_seq")
   topk_flags = gather(flags, "_flags")
   topk_gathered_scores = gather(scores_to_gather, "_scores")
-  if states is None:
-    topk_gathered_states = None
-  else:
-    topk_gathered_states = [gather(state, "_topk_states") for state in states]
-  return topk_seq, topk_gathered_scores, topk_flags, topk_gathered_states
+  return topk_seq, topk_gathered_scores, topk_flags, selector
 
 
 def beam_search(logits_fn,
@@ -213,9 +208,9 @@ def _my_concat(a, b):
     curr_finished_flags = _my_concat(finished_flags, curr_finished)
     return compute_topk_scores_and_seq(
         curr_finished_seq, curr_finished_scores, curr_finished_scores,
-        curr_finished_flags, beam_dim, "grow_finished", states=None)
+        curr_finished_flags, beam_dim, "grow_finished")
 
-  def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
+  def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished):
     """Given sequences and scores, will gather the top k=beam size sequences.
 
     Args:
@@ -226,7 +221,6 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
         [batch, beam]
       curr_finished: Finished flags for each of these sequences.
        [batch, beam]
-      states: list of mtf.Tensor
     Returns:
       Tuple of
         (Topk sequences based on scores,
@@ -238,7 +232,7 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
     curr_scores += mtf.cast(curr_finished, curr_scores.dtype) * -INF
     return compute_topk_scores_and_seq(curr_seq, curr_scores, curr_log_probs,
                                        curr_finished, beam_dim,
-                                       "grow_alive", states)
+                                       "grow_alive")
 
   def grow_topk(i, alive_seq, alive_log_probs, states=None):
     r"""Inner beam search loop.
@@ -298,6 +292,8 @@ def grow_topk(i, alive_seq, alive_log_probs, states=None):
     top_beam_index = top_ids // vocab_dim.size
     top_ids %= vocab_dim.size  # Unflatten the ids
 
+    selector = mtf.one_hot(top_beam_index, beam_dim, dtype=tf.float32)
+
     def my_gather(tensor):
       return mtf.gather(
           tensor, top_beam_index, beam_dim,
@@ -308,14 +304,12 @@ def my_gather(tensor):
     # bools
     top_seq = my_gather(alive_seq)
 
-    if states:
-      states = [my_gather(state) for state in new_states]
-
     # Append the most probable alive
     top_seq += top_ids * mtf.one_hot(i, length_dim, dtype=tf.int32)
     top_finished = mtf.equal(top_ids, eos_id)
 
-    return top_seq, top_log_probs, top_scores, top_finished, states
+    return (
+        top_seq, top_log_probs, top_scores, top_finished, new_states, selector)
 
   def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
                  finished_flags, *states):
@@ -368,14 +362,26 @@ def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
     # 2. Extract the ones that have finished and haven't finished
     # 3. Recompute the contents of finished based on scores.
     (top2k_seq, top2k_log_probs, top2k_scores, top2k_finished,
-     top2k_states) = grow_topk(i, alive_seq, alive_log_probs, states)
-    alive_seq, alive_log_probs, _, states = grow_alive(
-        top2k_seq, top2k_scores, top2k_log_probs, top2k_finished, top2k_states)
+     new_states, first_selector) = grow_topk(
+         i, alive_seq, alive_log_probs, states)
+    alive_seq, alive_log_probs, _, second_selector = grow_alive(
+        top2k_seq, top2k_scores, top2k_log_probs, top2k_finished)
     finished_seq, finished_scores, finished_flags, _ = grow_finished(
         finished_seq, finished_scores, finished_flags, top2k_seq, top2k_scores,
         top2k_finished)
+    old_beam_dim = mtf.Dimension("old_beam", beam_dim.size)
+    selector = mtf.einsum(
+        [mtf.rename_dimension(first_selector, beam_dim.name, old_beam_dim.name),
+         second_selector],
+        output_shape=[batch_dim, old_beam_dim, beam_dim])
+    new_states = [
+        mtf.einsum(
+            [mtf.rename_dimension(state, beam_dim.name, old_beam_dim.name),
+             mtf.cast(selector, state.dtype)],
+            reduced_dims=[old_beam_dim], output_shape=state.shape)
+        for state in new_states]
     return (i + 1, alive_seq, alive_log_probs, finished_seq, finished_scores,
-            finished_flags) + tuple(states)
+            finished_flags) + tuple(new_states)
 
   def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                    finished_scores, finished_in_finished, *unused_states):
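
The reason inner_loop keeps the two one-hot selectors (first_selector from grow_topk, second_selector from grow_alive) is that they can be composed into a single selector before the decoder states are touched, so each state is gathered once rather than twice. A rough numpy sketch of that equivalence, with assumed illustrative shapes and names (not taken from the file):

    import numpy as np

    batch, old_beam, candidates, beam, hidden = 2, 4, 8, 4, 16

    state = np.random.randn(batch, old_beam, hidden)
    # Step 1 (grow_topk): candidate c is expanded from old beam first_idx[b, c].
    first_idx = np.random.randint(0, old_beam, size=(batch, candidates))
    first_selector = np.eye(old_beam)[first_idx]      # [batch, candidates, old_beam]
    # Step 2 (grow_alive): surviving beam n keeps candidate second_idx[b, n].
    second_idx = np.random.randint(0, candidates, size=(batch, beam))
    second_selector = np.eye(candidates)[second_idx]  # [batch, beam, candidates]

    # Compose the two selections into one [batch, beam, old_beam] selector.
    selector = np.einsum("bnc,bco->bno", second_selector, first_selector)

    # One contraction with the composed selector ...
    once = np.einsum("bno,boh->bnh", selector, state)
    # ... equals gathering twice, candidate step then survivor step.
    twice = np.stack([state[b, first_idx[b]][second_idx[b]] for b in range(batch)])
    assert np.allclose(once, twice)

Expressing the composed gather as an einsum over the beam dimension is presumably what the commit message means by gathering slices in one operation instead of two.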
