@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
+        # `pos` is not provided, so we can use the first token
         next_token_logits = output[:, 0, :]
     else:
         # get the logits for each prompt at the specified positions
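For context, a minimal sketch of the two branches of `_batch_decode_next_tokens` (the shapes and `pos` values here are illustrative, not from the PR):

import torch

# Decode step (step != -1): each prompt contributes exactly one position,
# so the first (and only) position holds the next-token logits.
output = torch.randn(2, 1, 8)  # (batch_size=2, seq_len=1, vocab_size=8)
next_token_logits = output[:, 0, :]  # shape (2, 8)

# Prefill (step == -1): gather the logits at each prompt's own last
# position, e.g. pos = [3, 5] for prompts of unequal lengths.
prefill_output = torch.randn(2, 6, 8)
pos = torch.tensor([3, 5])
prefill_logits = prefill_output[torch.arange(2), pos]  # shape (2, 8)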
@@ -228,9 +229,9 @@ def _batch_decode_next_tokens(
         ).squeeze(-1)
     else:
         # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
     return next_tokens
 
 
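The `keepdim=True` change only affects the result's shape; a quick illustration (tensor values are made up):

import torch

logits = torch.randn(2, 8)  # (batch_size, vocab_size), illustrative

torch.argmax(logits, dim=-1).shape                # torch.Size([2])
torch.argmax(logits, dim=-1, keepdim=True).shape  # torch.Size([2, 1])

Keeping the trailing dim gives the (batch_size, 1) token tensor that `_decode_in_flight` below now documents and that the final `torch.cat` relies on.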
@@ -247,6 +248,11 @@ def _update_padded_sequence(
 # Decode token id into string and print it
 def _decode_in_flight(token, tokenizer, tp_rank):
     """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
     token_str = tokenizer.decode(token.tolist())
     # print the token string on tp rank 0
     if tp_rank == 0:
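A sketch of why the squeeze matters, assuming a tiktoken-style `decode` that takes a flat list of ids, while `SentencePieceProcessor.decode` also accepts a list of lists:

import torch

token = torch.tensor([[42], [7]])     # (batch_size, 1)

token.tolist()                        # [[42], [7]] -- nested, fine for SentencePiece
torch.squeeze(token, dim=1).tolist()  # [42, 7]     -- flat, what tiktoken expects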
@@ -530,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
            logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
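To see why `torch.cat` replaces `torch.stack` here: with the `keepdim=True` change above, each entry of `res` is now `(batch_size, 1)`, so stacking along dim 1 would yield a 3-D `(batch_size, steps, 1)` tensor, while cat'ing gives the 2-D per-prompt sequences a batch decode expects. A toy example (values made up):

import torch

# Three decode steps for a batch of two prompts, each a (2, 1) tensor.
res = [torch.tensor([[1], [4]]), torch.tensor([[2], [5]]), torch.tensor([[3], [6]])]

torch.cat(res, dim=1).tolist()  # [[1, 2, 3], [4, 5, 6]] -- one row per prompt
torch.stack(res, dim=1).shape   # torch.Size([2, 3, 1]) -- extra trailing dim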