@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
     batch_size, seq_len, vocab_size = output.shape

     if step != -1:
+        # `pos` is not provided, so we can use the first token
         next_token_logits = output[:, 0, :]
     else:
         # get the logits for each prompt at the specified positions
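A minimal sketch (not the repo's code) of the two selection paths above. The `pos` tensor and the gather indexing are illustrative assumptions, since the diff elides the prefill gather:

```python
import torch

batch_size, seq_len, vocab_size = 2, 4, 10
output = torch.randn(batch_size, seq_len, vocab_size)

# Decode step (step != -1): each sequence carries one fresh token,
# so the relevant logits sit at position 0.
decode_logits = output[:, 0, :]  # (batch_size, vocab_size)

# Prefill step (step == -1): gather logits at each prompt's last position.
# `pos` is hypothetical here; the real code receives it from the caller.
pos = torch.tensor([3, 1])
prefill_logits = output[torch.arange(batch_size), pos, :]  # (batch_size, vocab_size)

print(decode_logits.shape, prefill_logits.shape)
```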
@@ -228,9 +229,9 @@ def _batch_decode_next_tokens(
         ).squeeze(-1)
     else:
         # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)

-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
     return next_tokens


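The `keepdim=True` change is what keeps the argmax branch shape-compatible with the downstream code, which (per the comments below) expects a `(batch_size, 1)` tensor per step. A quick sketch of the difference:

```python
import torch

next_token_logits = torch.randn(3, 10)  # (batch_size, vocab_size)

flat = torch.argmax(next_token_logits, dim=-1)                # shape (3,)
kept = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # shape (3, 1)

print(flat.shape, kept.shape)
```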
@@ -247,6 +248,11 @@ def _update_padded_sequence(
 # Decode token id into string and print it
 def _decode_in_flight(token, tokenizer, tp_rank):
     """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
     token_str = tokenizer.decode(token.tolist())
     # print the token string on tp rank 0
     if tp_rank == 0:
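The squeeze exists because the two tokenizers accept different input shapes from `.tolist()`. A hedged sketch of that contract (the token values are arbitrary):

```python
import torch

token = torch.tensor([[42], [7]])  # (batch_size, 1), one new token per prompt

# TiktokenTokenizer.decode expects a flat list of ints,
# hence the squeeze to 1D first.
flat_ids = torch.squeeze(token, dim=1).tolist()  # [42, 7]

# SentencePieceProcessor.decode also accepts nested lists (batched decode),
# so the (batch_size, 1) form can be passed through unchanged.
nested_ids = token.tolist()  # [[42], [7]]

print(flat_ids, nested_ids)
```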
@@ -530,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text}{color.reset}")
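Assuming each element of `res` is one decode step's `(batch_size, 1)` token tensor, which is what the `keepdim=True` change above produces, the cat-then-decode flow looks like:

```python
import torch

# Three decode steps for a batch of two prompts (values are arbitrary).
res = [torch.tensor([[1], [4]]),
       torch.tensor([[2], [5]]),
       torch.tensor([[3], [6]])]

res_cat = torch.cat(res, dim=1)  # (batch_size, num_steps)
res_list = res_cat.tolist()      # [[1, 2, 3], [4, 5, 6]]

# tokenizer.decode(res_list) can then return one string per prompt,
# given a tokenizer whose decode supports batched (nested-list) input.
print(res_list)
```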