@@ -203,7 +203,7 @@ def _batch_decode_next_tokens(
     # Take the next token logits for each prompt
     next_token_logits = output[:, pos, :]
     # Argmax (deterministic) TODO: add temperature
-    next_token = torch.argmax(next_token_logits, dim=-1)
+    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
     # Token ids in int tensor form
     return next_token

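Note on the hunk above: with keepdim=True, torch.argmax preserves the reduced vocab dimension, so each decode step now returns a [batch, 1] column of token ids rather than a flat [batch] vector. A minimal sketch of the shape difference (batch and vocab sizes are made up for illustration):

    import torch

    logits = torch.randn(4, 32000)                     # [batch, vocab]
    torch.argmax(logits, dim=-1).shape                 # torch.Size([4])
    torch.argmax(logits, dim=-1, keepdim=True).shape   # torch.Size([4, 1])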
@@ -418,9 +418,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

     # Decode token id into string and print it
     def decode_in_flight(token):
-        # Make a 2D tensor with ids on row dimension
-        unsqueezed = torch.unsqueeze(token, 1)
-        token_str = tokenizer.decode(unsqueezed.tolist())
+        token_str = tokenizer.decode(token.tolist())
         if tp_rank == 0:
             logger.info(
                 f"{color.green}responses ====>>>> "
@@ -498,8 +496,10 @@ def decode_in_flight(token):

     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-        res = torch.stack(res, dim=1)
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
         res_list = res.tolist()
         response = tokenizer.decode(res_list)
         for i in range(len(response)):
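Why cat rather than stack: each entry of `res` is now a 2D [batch, 1] tensor (again a consequence of keepdim=True), so torch.stack(res, dim=1) would insert a new dimension and produce a 3D [batch, steps, 1] result, while torch.cat(res, dim=1) joins along the existing dim 1 to yield the intended [batch, steps] matrix of token ids. With the old 1D per-step tokens, stack along dim 1 built that same [batch, steps] shape, which is presumably why it was used before. A small sketch with made-up values:

    import torch

    # two decode steps, batch size 2, each step shaped [batch, 1]
    res = [torch.tensor([[1], [2]]), torch.tensor([[3], [4]])]
    torch.cat(res, dim=1)    # tensor([[1, 3], [2, 4]]) -- shape [2, 2]
    torch.stack(res, dim=1)  # shape [2, 2, 1], an extra trailing dim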