Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 20034bf

Browse files
daphneiMesh TensorFlow Team
authored andcommitted
Allow specifying value for top-k directly to Bitransformer decode function.
PiperOrigin-RevId: 325026516
1 parent 2fd4ecd commit 20034bf

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

mesh_tensorflow/transformer/transformer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,6 +1508,7 @@ def decode(self,
15081508
beam_size=1,
15091509
alpha=0.6,
15101510
temperature=0.0,
1511+
sampling_keep_top_k=-1,
15111512
decode_length_multiplier=1.5,
15121513
decode_length_constant=10,
15131514
max_decode_length=None):
@@ -1523,6 +1524,8 @@ def decode(self,
15231524
alpha: a floating point value (length bonus for beam search)
15241525
temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
15251526
0.0 means argmax, 1.0 means sample according to predicted distribution.
1527+
sampling_keep_top_k: a value between 1 and vocab_size used to sample from
1528+
only the k most likely logits. Set to -1 to sample from all logits.
15261529
decode_length_multiplier: a float
15271530
decode_length_constant: a float
15281531
max_decode_length: an optional integer
@@ -1558,6 +1561,7 @@ def decode(self,
15581561
return self.decoder.sample_autoregressive(
15591562
partial_sequences,
15601563
temperature=temperature,
1564+
sampling_keep_top_k=sampling_keep_top_k,
15611565
variable_dtype=variable_dtype,
15621566
encoder_output=encoder_output,
15631567
encoder_sequence_id=encoder_sequence_id,
@@ -1569,6 +1573,9 @@ def decode(self,
15691573
if temperature != 0:
15701574
raise ValueError(
15711575
"don't know how to beam search with nonzero temperature")
1576+
if sampling_keep_top_k != -1:
1577+
raise ValueError(
1578+
"don't know how to beam search with top-k value other than -1.")
15721579
# beam search
15731580
beam_dim = mtf.Dimension("beam", beam_size)
15741581
ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim])

0 commit comments

Comments
 (0)