@@ -442,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # New token generated each iteration
     # need a row dimension for each prompt in the batch
     new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
-    logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
     # Store the generated tokens
     res = []

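For orientation, a minimal sketch of the shapes involved (the values `batch_size = 2` and `num_steps = 4` are assumptions for illustration, not from this diff): each decode step yields one token id per prompt, collected as `(batch_size, 1)` tensors that are later cat'ed along dim 1:

```python
import torch

# Assumed toy values; the real code derives these from the prompts/args.
batch_size, num_steps = 2, 4

# One (batch_size, 1) int64 tensor of new token ids per decode step.
res = [torch.zeros(batch_size, 1, dtype=torch.int64) for _ in range(num_steps)]

# Cat along dim 1 -> one row of generated token ids per prompt.
tokens = torch.cat(res, dim=1)
print(tokens.shape)  # torch.Size([2, 4])
```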
@@ -519,7 +518,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

             # Decode the output
             if pp_rank == last_pp_rank:
-                # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
                 new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
                 res.append(new_token)
                 if not args.disable_in_flight_decode:
@@ -541,7 +539,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        responses = tokenizer.decode(res_list)
+        if isinstance(tokenizer, TiktokenTokenizer):
+            # For TiktokenTokenizer, we need to decode prompt by prompt.
+            # TODO: is there a better way to do this?
+            responses = [tokenizer.decode(sequence) for sequence in res_list]
+        else:  # SentencePieceProcessor
+            # For SentencePieceProcessor, we can decode the entire 2D list at once.
+            responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text}{color.reset}")
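The new branch exists because the two tokenizers expose different `decode` signatures: tiktoken's `decode` takes one flat list of token ids, while SentencePiece's also accepts a 2D list and returns one string per row. A hedged sketch of that difference, using stand-in decoders rather than the real `TiktokenTokenizer`/`SentencePieceProcessor` classes:

```python
def tiktoken_style_decode(ids):
    # Stand-in: decodes ONE flat list of token ids into one string.
    return " ".join(f"<{i}>" for i in ids)

def sentencepiece_style_decode(batch):
    # Stand-in: accepts a 2D list and returns one string per row.
    return [" ".join(f"<{i}>" for i in row) for row in batch]

res_list = [[1, 2, 3], [4, 5, 6]]  # one row of token ids per prompt

# Tiktoken path: decode prompt by prompt.
per_prompt = [tiktoken_style_decode(row) for row in res_list]

# SentencePiece path: decode the whole 2D list at once.
assert per_prompt == sentencepiece_style_decode(res_list)
```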