@@ -189,43 +189,51 @@ def _create_padded_prompts(
 
 
 def _batch_decode_next_tokens(
-    output: torch.Tensor, pos: List[int], step: int = -1, temperature: float = 1.0
+    output: torch.Tensor,
+    pos: List[int],
+    step: int = -1,
+    temperature: float = 1.0,
+    topk: int = 10,
 ) -> torch.Tensor:
194198 """
195- Decode the next token for each prompt in the batch.
199+ Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding.
200+
196201 Args:
197202 output (torch.Tensor): The output tensor to decode.
198- pos: the position of the `output` to decode in the sequence length dimension.
203+ pos (List[int]): The positions of the `output` to decode in the sequence length dimension.
204+ step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token.
205+ temperature (float): Sampling temperature for non-deterministic decoding.
199206
200207 Returns:
201- Decoded token ids.
208+ torch.Tensor: Decoded token ids.
202209 """
-    # Take the next token logits for each prompt
-    res = []
-    # logger.info(f"{color.green}output shape = {output.shape}{color.reset}")
-    # logger.info(f"{color.green}pos = {pos}{color.reset}")
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
         next_token_logits = output[:, 0, :]
-        next_token = torch.argmax(next_token_logits, dim=-1)
-        res.append(next_token)
-        res = torch.stack(res, dim=0)
-        res = res.squeeze(0)
     else:
-        for i in range(batch_size):
-            token_pos = pos[i] - 1
-            next_token_logits = output[i, token_pos, :]
-
-            # Argmax (deterministic) TODO: add temperature
-            next_token = torch.argmax(next_token_logits, dim=-1)
-            # logger.info(f"{color.blue}next_token = {next_token}{color.reset}")
-            res.append(next_token)
-    # Token ids in int tensor form
-    res = torch.stack(res, dim=0)
-
-    logger.info(f"{color.yellow}next_token = {res}{color.reset}")
-    return res  # next_token
+        # Gather the logits for each prompt at its last filled position
+        next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
+
+    if temperature != 1.0:
+        next_token_logits = next_token_logits / temperature
+
+    # Use top-k sampling if temperature is not 1.0; otherwise use argmax
+    if temperature != 1.0:
+        top_k = min(topk, vocab_size)
+        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
+        probs = torch.softmax(top_k_logits, dim=-1)
+        next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
+        next_tokens = top_k_indices.gather(
+            -1, next_token_indices.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        # Argmax (deterministic)
+        next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    return next_tokens
 
 
 def _update_padded_sequence(
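For reviewers, here is a minimal sketch of the two decode paths in the new code, run against dummy logits. The tensor shapes, the `pos` values, and the sample hyperparameters are assumptions made purely for illustration; only the indexing and top-k logic mirror the diff.

```python
import torch

# Hypothetical shapes for illustration; not taken from the surrounding file.
batch_size, seq_len, vocab_size = 2, 8, 32
output = torch.randn(batch_size, seq_len, vocab_size)  # stand-in model logits
pos = [5, 8]  # per-prompt filled lengths; next token decoded at pos[i] - 1

# Gather each prompt's logits at its last filled position (the `else` branch).
next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]

# Deterministic path (temperature == 1.0): plain argmax.
greedy = torch.argmax(next_token_logits, dim=-1)

# Sampling path (temperature != 1.0): scale, take top-k, sample from the softmax.
temperature, topk = 0.8, 10
scaled = next_token_logits / temperature
top_k_logits, top_k_indices = torch.topk(scaled, k=min(topk, vocab_size), dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)
choice = torch.multinomial(probs, num_samples=1)
sampled = top_k_indices.gather(-1, choice).squeeze(-1)

print(greedy.shape, sampled.shape)  # torch.Size([2]) torch.Size([2])
```

Restricting sampling to the top-k logits before the softmax keeps low-probability tail tokens from ever being drawn, which is the usual motivation for pairing top-k with temperature scaling.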