@@ -1508,6 +1508,7 @@ def decode(self,
15081508 beam_size = 1 ,
15091509 alpha = 0.6 ,
15101510 temperature = 0.0 ,
1511+ sampling_keep_top_k = - 1 ,
15111512 decode_length_multiplier = 1.5 ,
15121513 decode_length_constant = 10 ,
15131514 max_decode_length = None ):
@@ -1523,6 +1524,8 @@ def decode(self,
15231524 alpha: a floating point value (length bonus for beam search)
15241525 temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
15251526 0.0 means argmax, 1.0 means sample according to predicted distribution.
1527+ sampling_keep_top_k: a value between 1 and vocab_size used to sample from
1528+ only the k most likely logits. Set to -1 to sample from all logits.
15261529 decode_length_multiplier: a float
15271530 decode_length_constant: a float
15281531 max_decode_length: an optional integer
@@ -1558,6 +1561,7 @@ def decode(self,
15581561 return self .decoder .sample_autoregressive (
15591562 partial_sequences ,
15601563 temperature = temperature ,
1564+ sampling_keep_top_k = sampling_keep_top_k ,
15611565 variable_dtype = variable_dtype ,
15621566 encoder_output = encoder_output ,
15631567 encoder_sequence_id = encoder_sequence_id ,
@@ -1569,6 +1573,9 @@ def decode(self,
15691573 if temperature != 0 :
15701574 raise ValueError (
15711575 "don't know how to beam search with nonzero temperature" )
1576+ if sampling_keep_top_k != - 1 :
1577+ raise ValueError (
1578+ "don't know how to beam search with top-k value other than -1." )
15721579 # beam search
15731580 beam_dim = mtf .Dimension ("beam" , beam_size )
15741581 ids_shape = mtf .Shape (batch_dims + [beam_dim , decode_length_dim ])
0 commit comments