@@ -187,30 +187,23 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    tokenizer,
-    prompt_lengths: Optional[List[int]] = None,
-) -> List[Tuple[int, str]]:
+    pos: int,
+) -> torch.Tensor:
     """
     Decode the next token for each prompt in the batch.
+    Args:
+        output (torch.Tensor): The output tensor to decode.
+        pos: the position of the `output` to decode in the sequence length dimension.
 
     Returns:
-        List[Tuple[int, str]]: List of tuples containing the next token id and its
-        decoded string for each prompt in the batch.
+        Decoded token ids.
     """
-    batch_size = output.shape[0]
-    results = []
-
-    for i in range(batch_size):
-        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
-        next_token_logits = output[i, pos, :]
-
-        # Argmax (deterministic) TODO: add temperature
-        next_token = torch.argmax(next_token_logits, dim=-1)
-
-        next_token_decoded = tokenizer.decode([next_token.item()])
-        results.append((next_token.item(), next_token_decoded))
-
-    return results
+    # Take the next token logits for each prompt
+    next_token_logits = output[:, pos, :]
+    # Argmax (deterministic) TODO: add temperature
+    next_token = torch.argmax(next_token_logits, dim=-1)
+    # Token ids in int tensor form
+    return next_token
 
 
 def _update_padded_sequence(
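
The refactor replaces the old per-prompt Python loop with a single vectorized slice over the batch dimension. A minimal, self-contained sketch of the equivalence (shapes and values here are illustrative, not from the repo):

    import torch

    batch, seqlen, vocab = 2, 5, 8
    output = torch.randn(batch, seqlen, vocab)  # stand-in for model logits
    pos = 3                                     # position to decode

    # Old approach: one argmax per prompt in a Python loop
    looped = torch.stack(
        [torch.argmax(output[i, pos, :], dim=-1) for i in range(batch)]
    )

    # New approach: slice the whole batch at `pos`, one argmax call
    vectorized = torch.argmax(output[:, pos, :], dim=-1)

    assert torch.equal(looped, vectorized)      # same ids, shape (batch,)

Note that the new function no longer touches the tokenizer; turning ids into strings is split out so it can be skipped entirely.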
@@ -401,8 +394,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # New token generated each iteration
     # need a row dimension for each prompt in the batch
     new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
-    res = [[] for _ in range(batch_size)]
-    num_tokens = 40
+    # Store the generated tokens
+    res = []
 
     # Prefill phase
     # Run context input through pipeline
@@ -421,23 +414,24 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
 
-    # Decode the output -- first generated token
-    if pp_rank == last_pp_rank:
-        decode_results = _batch_decode_next_tokens(
-            output=output,
-            tokenizer=tokenizer,
-            prompt_lengths=prompt_lengths,
-        )
-        for i in range(len(decode_results)):
-            new_token[i, 0] = torch.tensor(
-                [decode_results[i][0]], device=device
-            )  # token_id in int form
+    # Decode token id into string and print it
+    def decode_in_flight(token):
+        # Make a 2D tensor with ids on row dimension
+        unsqueezed = torch.unsqueeze(token, 1)
+        token_str = tokenizer.decode(unsqueezed.tolist())
         if tp_rank == 0:
             logger.info(
-                f"{color.green}{'* Prefill *'} "
-                f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
+                f"{color.green}responses ====>>>> "
+                f"{color.blue}{token_str}{color.reset}"
             )
 
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
+        res.append(new_token)
+        if not args.disable_in_flight_decode:
+            decode_in_flight(new_token)
+
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
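
`decode_in_flight` hands `tokenizer.decode` a list of single-id rows, so it assumes the tokenizer can decode a batch (a list of token-id lists) and return one string per row. A toy stand-in makes the shape handling concrete; `ToyTokenizer` below is hypothetical, not the repo's tokenizer:

    import torch

    class ToyTokenizer:
        vocab = {0: "hello", 1: "world"}

        def decode(self, batch_ids):
            # Batch decode: one string per row of token ids
            return [" ".join(self.vocab[i] for i in row) for row in batch_ids]

    tokenizer = ToyTokenizer()
    new_token = torch.tensor([0, 1])              # one id per prompt, shape (batch,)
    unsqueezed = torch.unsqueeze(new_token, 1)    # shape (batch, 1)
    print(tokenizer.decode(unsqueezed.tolist()))  # ['hello', 'world']

The prefill decode reads position `prompt_lengths[0] - 1`, i.e. the logits at the last prompt token, which predict the first generated token.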
@@ -459,7 +453,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
-        for step in range(num_tokens - 1):
+        for step in range(args.ntokens - 1):
             kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
@@ -486,21 +480,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Decode the output
             if pp_rank == last_pp_rank:
-                decode_results = _batch_decode_next_tokens(
-                    output=output, tokenizer=tokenizer
-                )
-                if tp_rank == 0:
-                    logger.info(
-                        f"{color.green}{'* Decode *'} "
-                        f"responses ====>>>> {color.blue}{decode_results=}{color.reset}"
-                    )
-                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-                for i in range(len(decode_results)):
-                    res[i].append(decode_results[i][1])
-                    new_token[i, 0] = torch.tensor(
-                        [decode_results[i][0]], device=device
-                    )  # decode_results[i][0]
+                new_token = _batch_decode_next_tokens(output, 0)
+                res.append(new_token)
+                if not args.disable_in_flight_decode:
+                    decode_in_flight(new_token)
 
+            # Increment input position
             input_pos += 1
 
     logger.info(
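
During the decode loop each forward pass runs with `seqlen_decode = 1`, so the next-token logits always sit at sequence position 0; that is why the call above passes `0` where prefill passed `prompt_lengths[0] - 1`. A quick sketch of the shape involved (illustrative values):

    import torch

    vocab = 8
    decode_output = torch.randn(2, 1, vocab)  # (batch, seqlen=1, vocab)
    next_token = torch.argmax(decode_output[:, 0, :], dim=-1)
    print(next_token.shape)                   # torch.Size([2])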
@@ -511,21 +496,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        for i in range(len(prompt_lengths)):
-            logger.info(f"\nPrompt:{color.green}{prompt[i]} {color.reset}")
-
-            # TODO: resolve issue with llama2-7b-chat model and "".join
-            if model_name != "llama2-7b-chat":
-                formatted_response = "".join(res[i])
-            else:
-                formatted_response = " ".join(res[i])
-            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
+        # `res` is a list of tensors, each being a batch of generated token ids
+        res = torch.stack(res, dim=1)
+        res_list = res.tolist()
+        response = tokenizer.decode(res_list)
+        for i in range(len(response)):
+            logger.info(f"$$ {color.red}{response[i]} {color.reset} $$\n")
 
     # Cleanup
+    _cleanup()
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
-    _cleanup()
 
 
 if __name__ == "__main__":
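
After the loop, `res` holds one `(batch,)` tensor per generated step, so `torch.stack(res, dim=1)` yields a `(batch, num_generated)` tensor whose rows are per-prompt token-id sequences, ready for one batch decode. Illustrative values:

    import torch

    # Three decode steps for a batch of two prompts
    res = [torch.tensor([5, 9]), torch.tensor([6, 10]), torch.tensor([7, 11])]
    stacked = torch.stack(res, dim=1)  # shape (2, 3)
    print(stacked.tolist())            # [[5, 6, 7], [9, 10, 11]]

This also retires the old per-model `"".join` vs `" ".join` workaround, since the tokenizer now reassembles each full sequence from ids itself.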
@@ -537,6 +519,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
     )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    parser.add_argument(
+        "--ntokens",
+        type=int,
+        default=40,
+        help="Number of tokens to generate",
+    )
+    parser.add_argument(
+        "--disable-in-flight-decode",
+        action="store_true",
+        default=False,
+        help="Whether to decode token into string in flight",
+    )
     args = parser.parse_args()
 
     main(args)
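
The two new flags replace the hard-coded `num_tokens = 40`. argparse maps the dashed flag name to an underscored attribute, which is what the `args.disable_in_flight_decode` checks above rely on. A minimal sketch with an illustrative standalone parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ntokens", type=int, default=40, help="Number of tokens to generate")
    parser.add_argument("--disable-in-flight-decode", action="store_true", default=False)

    # Dashes in the flag name become underscores on the namespace
    args = parser.parse_args(["--ntokens", "64", "--disable-in-flight-decode"])
    assert args.ntokens == 64
    assert args.disable_in_flight_decode is True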