@@ -275,6 +275,23 @@ def _update_padded_sequence(
         # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
+# Decode token id into string and print it
+def _decode_in_flight(token, tokenizer, tp_rank):
+    # Make a 2D tensor with ids on row dimension
+    # unsqueezed = torch.unsqueeze(token, 1)
+    # token_str = tokenizer.decode(unsqueezed.tolist())
+    # tiktoken does not accept tensor inputs
+    decoding_list = token.tolist()
+    token_str = tokenizer.decode(decoding_list)
+
+    # print the token string on tp rank 0
+    if tp_rank == 0:
+        logger.info(
+            f"{color.green} responses ====>>>> "
+            f"{color.blue} {token_str} {color.reset}"
+        )
+
+
 def _cleanup():
     dist.barrier()
     dist.destroy_process_group()
@@ -444,11 +461,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen_prefill, start_pos, device
     )
-    # TODO: figure out how to set input_pos for each prompt in the batch then we
-    # can remove this limitation.
-    # s = set(prompt_lengths)
-    logger.info(f"prompt_lengths = {prompt_lengths = }")
-    # assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
@@ -477,39 +489,14 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     logger.info(
         f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
-    logger.info(
-        f"{color.green}{input_pos=}, Prefill output: {type(output)=}{color.reset}"
-    )
-
-    # Decode token id into string and print it
-    def decode_in_flight(token):
-        # Make a 2D tensor with ids on row dimension
-        # unsqueezed = torch.unsqueeze(token, 1)
-        # token_str = tokenizer.decode(unsqueezed.tolist())
-        # tiktoken does not accept tensor inputs
-        decoding_list = token.tolist()
-        logger.info(
-            f"{color.red}decoding token {token=} {type(token)=}, {decoding_list=}{color.reset}"
-        )
-        token_str = tokenizer.decode(decoding_list)
-        logger.info(
-            f"{color.green}prefile response ===>>>> token_str = {token_str}{color.reset}"
-        )
-        tp_rank = dist.get_rank()
-        # print the token string on tp rank 0
-        if tp_rank == 0:
-            logger.info(
-                f"{color.green} responses ====>>>> "
-                f"{color.blue} {token_str} {color.reset}"
-            )
 
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
         logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
         new_token = _batch_decode_next_tokens(output, prompt_lengths)
         res.append(new_token)
         if not args.disable_in_flight_decode:
-            decode_in_flight(new_token)
+            _decode_in_flight(new_token, tokenizer, tp_rank)
 
     # seqlen = 1 now
     seqlen_decode = 1
@@ -559,11 +546,11 @@ def decode_in_flight(token):
 
     # Decode the output
     if pp_rank == last_pp_rank:
-        logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
+        # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
         new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
         res.append(new_token)
         if not args.disable_in_flight_decode:
-            decode_in_flight(new_token)
+            _decode_in_flight(new_token, tokenizer, tp_rank)
 
     # Increment input position
     input_pos += 1
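
For reference, a minimal standalone sketch of the refactored in-flight decode path. The _FakeTokenizer stand-in and the simplified logging (no color formatting, no distributed setup) are assumptions for illustration only; the real script passes its tiktoken-style tokenizer and the tensor-parallel rank from the pipeline.

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _FakeTokenizer:
    # Stand-in for the real tokenizer: renders token ids as placeholder strings.
    def decode(self, ids):
        return " ".join(f"<tok{i}>" for i in ids)


def _decode_in_flight(token, tokenizer, tp_rank):
    # tiktoken does not accept tensor inputs, so convert to a plain Python list first
    decoding_list = token.tolist()
    token_str = tokenizer.decode(decoding_list)
    # Only tensor-parallel rank 0 logs, so each response is printed once per TP group
    if tp_rank == 0:
        logger.info(f"responses ====>>>> {token_str}")


if __name__ == "__main__":
    # Pretend batch of next-token ids produced by _batch_decode_next_tokens
    new_token = torch.tensor([42, 7, 1234])
    _decode_in_flight(new_token, _FakeTokenizer(), tp_rank=0)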