This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 41d61a8

update _decode_in_flight

1 parent: a512141

File tree

1 file changed: +20 -33 lines changed


dist_run.py

Lines changed: 20 additions & 33 deletions
```diff
@@ -275,6 +275,23 @@ def _update_padded_sequence(
         # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
+# Decode token id into string and print it
+def _decode_in_flight(token, tokenizer, tp_rank):
+    # Make a 2D tensor with ids on row dimension
+    # unsqueezed = torch.unsqueeze(token, 1)
+    # token_str = tokenizer.decode(unsqueezed.tolist())
+    # tiktoken does not accept tensor inputs
+    decoding_list = token.tolist()
+    token_str = tokenizer.decode(decoding_list)
+
+    # print the token string on tp rank 0
+    if tp_rank == 0:
+        logger.info(
+            f"{color.green} responses ====>>>> "
+            f"{color.blue} {token_str} {color.reset}"
+        )
+
+
 def _cleanup():
     dist.barrier()
     dist.destroy_process_group()
```
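The commented-out lines kept in the new helper record why the conversion is needed: tiktoken's `decode` takes a plain list of ints, not a tensor, so the token tensor is flattened with `tolist()` first. A minimal sketch of that constraint, assuming a tiktoken encoding (the encoding name below is illustrative, not taken from the repo):

```python
# Minimal sketch of the tensor-to-list conversion _decode_in_flight performs.
# tiktoken's Encoding.decode expects a list of ints, so the token tensor is
# converted with .tolist() before decoding.
import tiktoken
import torch

enc = tiktoken.get_encoding("cl100k_base")  # illustrative encoding choice

token = torch.tensor(enc.encode("hello world"))  # stand-in for a tensor of token ids
# enc.decode(token)                 # rejected: decode() does not accept tensors
print(enc.decode(token.tolist()))   # -> "hello world"
```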
```diff
@@ -444,11 +461,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen_prefill, start_pos, device
     )
-    # TODO: figure out how to set input_pos for each prompt in the batch then we
-    # can remove this limitation.
-    # s = set(prompt_lengths)
-    logger.info(f"prompt_lengths = {prompt_lengths=}")
-    # assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
```
```diff
@@ -477,39 +489,14 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     logger.info(
         f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
-    logger.info(
-        f"{color.green}{input_pos=}, Prefill output: {type(output)=}{color.reset}"
-    )
-
-    # Decode token id into string and print it
-    def decode_in_flight(token):
-        # Make a 2D tensor with ids on row dimension
-        # unsqueezed = torch.unsqueeze(token, 1)
-        # token_str = tokenizer.decode(unsqueezed.tolist())
-        # tiktoken does not accept tensor inputs
-        decoding_list = token.tolist()
-        logger.info(
-            f"{color.red} decoding token {token=}{type(token)=}, {decoding_list=} {color.reset}"
-        )
-        token_str = tokenizer.decode(decoding_list)
-        logger.info(
-            f"{color.green} prefile response ===>>>> token_str = {token_str}{color.reset}"
-        )
-        tp_rank = dist.get_rank()
-        # print the token string on tp rank 0
-        if tp_rank == 0:
-            logger.info(
-                f"{color.green} responses ====>>>> "
-                f"{color.blue} {token_str} {color.reset}"
-            )
 
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
         logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
         new_token = _batch_decode_next_tokens(output, prompt_lengths)
         res.append(new_token)
         if not args.disable_in_flight_decode:
-            decode_in_flight(new_token)
+            _decode_in_flight(new_token, tokenizer, tp_rank)
 
     # seqlen = 1 now
     seqlen_decode = 1
```
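The net effect of this hunk (and the matching one below in the decode loop) is a hoist: the per-call closure `decode_in_flight`, which captured `tokenizer` implicitly and called `dist.get_rank()` on every token, becomes the module-level `_decode_in_flight` with both dependencies passed as parameters. A simplified before/after sketch of the pattern (names taken from the diff; the enclosing structure is assumed for illustration):

```python
# Before: nested closure, implicit dependencies, rank queried on every call.
def generate():                                        # assumed enclosing function
    def decode_in_flight(token):
        token_str = tokenizer.decode(token.tolist())   # captures tokenizer from scope
        if dist.get_rank() == 0:                       # rank looked up per token
            logger.info(token_str)
    ...

# After: module-level helper, dependencies explicit, shared by the
# prefill path and the decode loop.
def _decode_in_flight(token, tokenizer, tp_rank):
    token_str = tokenizer.decode(token.tolist())
    if tp_rank == 0:
        logger.info(token_str)
```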
```diff
@@ -559,11 +546,11 @@ def decode_in_flight(token):
 
     # Decode the output
     if pp_rank == last_pp_rank:
-        logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
+        # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
         new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
         res.append(new_token)
         if not args.disable_in_flight_decode:
-            decode_in_flight(new_token)
+            _decode_in_flight(new_token, tokenizer, tp_rank)
 
     # Increment input position
     input_pos += 1
```
