@@ -187,30 +187,23 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    tokenizer,
-    prompt_lengths: Optional[List[int]] = None,
-) -> List[Tuple[int, str]]:
+    pos: int,
+) -> torch.Tensor:
     """
     Decode the next token for each prompt in the batch.
+    Args:
+        output (torch.Tensor): The output tensor to decode.
+        pos (int): The position along the sequence-length dimension of `output` to decode.
 
     Returns:
-        List[Tuple[int, str]]: List of tuples containing the next token id and its
-        decoded string for each prompt in the batch.
+        torch.Tensor: The decoded token ids, one per prompt.
     """
-    batch_size = output.shape[0]
-    results = []
-
-    for i in range(batch_size):
-        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
-        next_token_logits = output[i, pos, :]
-
-        # Argmax (deterministic) TODO: add temperature
-        next_token = torch.argmax(next_token_logits, dim=-1)
-
-        next_token_decoded = tokenizer.decode([next_token.item()])
-        results.append((next_token.item(), next_token_decoded))
-
-    return results
+    # Take the next-token logits for each prompt
+    next_token_logits = output[:, pos, :]
+    # Argmax (deterministic) TODO: add temperature
+    next_token = torch.argmax(next_token_logits, dim=-1)
+    # Token ids in int tensor form
+    return next_token
 
 
 def _update_padded_sequence(
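The refactor above makes next-token selection fully batched: one slice plus one `argmax` over the vocab dimension replaces the per-prompt Python loop, and the single `pos` argument means callers pass one position for the whole batch (`prompt_lengths[0] - 1` at prefill, `0` during decode, per the changes below). A minimal sketch of the new behavior, with illustrative shapes only:

```python
import torch

def _batch_decode_next_tokens(output: torch.Tensor, pos: int) -> torch.Tensor:
    # output: [batch_size, seq_len, vocab_size] logits from the last pipeline stage
    next_token_logits = output[:, pos, :]            # [batch_size, vocab_size]
    return torch.argmax(next_token_logits, dim=-1)   # [batch_size] token ids

# Illustrative shapes: 4 prompts, 16 positions, vocab of 32000
logits = torch.randn(4, 16, 32000)
print(_batch_decode_next_tokens(logits, pos=15))     # tensor of 4 token ids
```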
@@ -399,11 +392,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
 
     # New token generated each iteration
-    total_prompts = len(prompt_lengths)
-    # need a new token dimension (row) for each prompt in the batch
-    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
-    res = [[] for _ in range(total_prompts)]
-    num_tokens = 40
+    # need a row dimension for each prompt in the batch
+    new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
+    # Store the generated tokens
+    res = []
 
     # Prefill phase
     # Run context input through pipeline
@@ -422,23 +414,24 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         f"{color.green}{timer.get_time()} {timer.unit} {rank}{color.reset}"
     )
 
-    # Decode the output -- first generated token
-    if pp_rank == last_pp_rank:
-        decode_results = _batch_decode_next_tokens(
-            output=output,
-            tokenizer=tokenizer,
-            prompt_lengths=prompt_lengths,
-        )
-        for i in range(len(decode_results)):
-            new_token[i, 0] = torch.tensor(
-                [decode_results[i][0]], device=device
-            )  # token_id in int form
+    # Decode token id into string and print it
+    def decode_in_flight(token):
+        # Make a 2D tensor with ids on row dimension
+        unsqueezed = torch.unsqueeze(token, 1)
+        token_str = tokenizer.decode(unsqueezed.tolist())
         if tp_rank == 0:
             logger.info(
-                f"{color.green}{'* Prefill *'} "
-                f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
+                f"{color.green}responses ====>>>> "
+                f"{color.blue}{token_str} {color.reset}"
             )
 
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
+        res.append(new_token)
+        if not args.disable_in_flight_decode:
+            decode_in_flight(new_token)
+
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
@@ -460,7 +453,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
-        for step in range(num_tokens - 1):
+        for step in range(args.ntokens - 1):
             kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
@@ -487,21 +480,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Decode the output
             if pp_rank == last_pp_rank:
-                decode_results = _batch_decode_next_tokens(
-                    output=output, tokenizer=tokenizer
-                )
-                if tp_rank == 0:
-                    logger.info(
-                        f"{color.green}{'* Decode *'} "
-                        f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
-                    )
-                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-                for i in range(len(decode_results)):
-                    res[i].append(decode_results[i][1])
-                    new_token[i, 0] = torch.tensor(
-                        [decode_results[i][0]], device=device
-                    )  # decode_results[i][0]
+                new_token = _batch_decode_next_tokens(output, 0)
+                res.append(new_token)
+                if not args.disable_in_flight_decode:
+                    decode_in_flight(new_token)
 
+            # Increment input position
             input_pos += 1
 
     logger.info(
@@ -512,21 +496,19 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        for i in range(len(prompt_lengths)):
-            logger.info(f"\nPrompt: {color.green}{prompt[i]}{color.reset}")
-
-            # TODO: resolve issue with llama2-7b-chat model and "".join
-            if model_name != "llama2-7b-chat":
-                formatted_response = "".join(res[i])
-            else:
-                formatted_response = " ".join(res[i])
-            logger.info(f"$$ {color.red}{formatted_response}{color.reset}\n")
+        # `res` is a list of tensors, each being a batch of generated token ids
+        res = torch.stack(res, dim=1)
+        res_list = res.tolist()
+        response = tokenizer.decode(res_list)
+        for i in range(len(response)):
+            logger.info(f"Prompt: {color.green}{prompt[i]}{color.reset}")
+            logger.info(f"Response: {color.red}{response[i]}{color.reset}")
 
     # Cleanup
+    _cleanup()
     logger.info(
         f"{color.green}{color.white}{color.blue}{rank}{color.reset}"
     )
-    _cleanup()
 
 
 if __name__ == "__main__":
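For readers following the response-formatting change above: `res` now holds one `[batch]` tensor of token ids per decode step, so stacking along `dim=1` yields a `[batch, num_generated]` tensor whose rows can be decoded per prompt. A small sketch of that reshaping (the final `tokenizer.decode(res_list)` call is only described in a comment, since its batch signature depends on the tokenizer in use):

```python
import torch

# Suppose 3 decode steps were run for a batch of 2 prompts
res = [torch.tensor([11, 21]), torch.tensor([12, 22]), torch.tensor([13, 23])]

stacked = torch.stack(res, dim=1)   # shape [2, 3]: one row of generated ids per prompt
res_list = stacked.tolist()         # [[11, 12, 13], [21, 22, 23]]
# tokenizer.decode(res_list) is then expected to return one string per row,
# which the loop in the diff prints alongside its prompt.
```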
@@ -538,6 +520,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
     )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    parser.add_argument(
+        "--ntokens",
+        type=int,
+        default=40,
+        help="Number of tokens to generate",
+    )
+    parser.add_argument(
+        "--disable-in-flight-decode",
+        action="store_true",
+        default=False,
+        help="Disable decoding each generated token into a string in flight",
+    )
     args = parser.parse_args()
 
     main(args)
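A hedged sketch of how the two new arguments surface in `args` (argument names are taken from the diff; the `parse_args` input is illustrative). Note that argparse maps the dashed flag to an underscored attribute, which is what the generation loop checks via `args.disable_in_flight_decode`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ntokens", type=int, default=40, help="Number of tokens to generate")
parser.add_argument(
    "--disable-in-flight-decode",
    action="store_true",
    default=False,
    help="Disable decoding each generated token into a string in flight",
)

# Illustrative invocation: generate 100 tokens without in-flight string decoding
args = parser.parse_args(["--ntokens", "100", "--disable-in-flight-decode"])
assert args.ntokens == 100
assert args.disable_in_flight_decode is True  # dashes become underscores
```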