This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 6fd90bc

[Distributed] Make decode in flight optional (#1180)
* Replace total_prompts with batch_size
* Make in-flight decoding optional
* Add back prompt print
1 parent a1a4682 commit 6fd90bc


dist_run.py

Lines changed: 56 additions & 62 deletions
@@ -187,30 +187,23 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    tokenizer,
-    prompt_lengths: Optional[List[int]] = None,
-) -> List[Tuple[int, str]]:
+    pos: int,
+) -> torch.Tensor:
     """
     Decode the next token for each prompt in the batch.
+    Args:
+        output (torch.Tensor): The output tensor to decode.
+        pos: the position of the `output` to decode in the sequence length dimension.
 
     Returns:
-        List[Tuple[int, str]]: List of tuples containing the next token id and its
-        decoded string for each prompt in the batch.
+        Decoded token ids.
     """
-    batch_size = output.shape[0]
-    results = []
-
-    for i in range(batch_size):
-        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
-        next_token_logits = output[i, pos, :]
-
-        # Argmax (deterministic) TODO: add temperature
-        next_token = torch.argmax(next_token_logits, dim=-1)
-
-        next_token_decoded = tokenizer.decode([next_token.item()])
-        results.append((next_token.item(), next_token_decoded))
-
-    return results
+    # Take the next token logits for each prompt
+    next_token_logits = output[:, pos, :]
+    # Argmax (deterministic) TODO: add temperature
+    next_token = torch.argmax(next_token_logits, dim=-1)
+    # Token ids in int tensor form
+    return next_token
 
 
 def _update_padded_sequence(
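
For context (not part of the commit), a minimal sketch of how the rewritten helper behaves, assuming the logits tensor has shape [batch_size, seqlen, vocab_size] and that the function can be imported from dist_run:

import torch

from dist_run import _batch_decode_next_tokens  # assumption: dist_run.py is importable

# Toy logits: batch of 2 prompts, sequence length 8, vocab size 32
output = torch.randn(2, 8, 32)

# Greedily decode the token at sequence position 5 for all prompts at once
next_token = _batch_decode_next_tokens(output, pos=5)
print(next_token.shape)  # torch.Size([2]) -- one token id per prompt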
@@ -399,11 +392,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
 
     # New token generated each iteration
-    total_prompts = len(prompt_lengths)
-    # need a new token dimension (row) for each prompt in the batch
-    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
-    res = [[] for _ in range(total_prompts)]
-    num_tokens = 40
+    # need a row dimension for each prompt in the batch
+    new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
+    # Store the generated tokens
+    res = []
 
     # Prefill phase
     # Run context input through pipeline
@@ -422,23 +414,24 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
 
-    # Decode the output -- first generated token
-    if pp_rank == last_pp_rank:
-        decode_results = _batch_decode_next_tokens(
-            output=output,
-            tokenizer=tokenizer,
-            prompt_lengths=prompt_lengths,
-        )
-        for i in range(len(decode_results)):
-            new_token[i, 0] = torch.tensor(
-                [decode_results[i][0]], device=device
-            )  # token_id in int form
+    # Decode token id into string and print it
+    def decode_in_flight(token):
+        # Make a 2D tensor with ids on row dimension
+        unsqueezed = torch.unsqueeze(token, 1)
+        token_str = tokenizer.decode(unsqueezed.tolist())
         if tp_rank == 0:
             logger.info(
-                f"{color.green} {'* Prefill *'} "
-                f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
+                f"{color.green} responses ====>>>> "
+                f"{color.blue} {token_str} {color.reset}"
             )
 
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
+        res.append(new_token)
+        if not args.disable_in_flight_decode:
+            decode_in_flight(new_token)
+
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
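
A quick sketch (not in the diff) of the tensor reshaping that decode_in_flight relies on, using a toy batch of two token ids:

import torch

token = torch.tensor([128, 42])          # one generated token id per prompt, shape [2]
unsqueezed = torch.unsqueeze(token, 1)   # shape [2, 1]: ids on the row dimension
print(unsqueezed.tolist())               # [[128], [42]] -- the list-of-lists form passed to tokenizer.decode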
@@ -460,7 +453,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
-        for step in range(num_tokens - 1):
+        for step in range(args.ntokens - 1):
             kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
@@ -487,21 +480,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Decode the output
             if pp_rank == last_pp_rank:
-                decode_results = _batch_decode_next_tokens(
-                    output=output, tokenizer=tokenizer
-                )
-                if tp_rank == 0:
-                    logger.info(
-                        f"{color.green} {'* Decode *'} "
-                        f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
-                    )
-                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-                for i in range(len(decode_results)):
-                    res[i].append(decode_results[i][1])
-                    new_token[i, 0] = torch.tensor(
-                        [decode_results[i][0]], device=device
-                    )  # decode_results[i][0]
+                new_token = _batch_decode_next_tokens(output, 0)
+                res.append(new_token)
+                if not args.disable_in_flight_decode:
+                    decode_in_flight(new_token)
 
+            # Increment input position
             input_pos += 1
 
     logger.info(
@@ -512,21 +496,19 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
    if pp_rank == last_pp_rank and tp_rank == 0:
-        for i in range(len(prompt_lengths)):
-            logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")
-
-            # TODO: resolve issue with llama2-7b-chat model and "".join
-            if model_name != "llama2-7b-chat":
-                formatted_response = "".join(res[i])
-            else:
-                formatted_response = " ".join(res[i])
-            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
+        # `res` is a list of tensors, each being a batch of generated token ids
+        res = torch.stack(res, dim=1)
+        res_list = res.tolist()
+        response = tokenizer.decode(res_list)
+        for i in range(len(response)):
+            logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
+            logger.info(f"Response: {color.red}{response[i]} {color.reset}")
 
     # Cleanup
+    _cleanup()
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
-    _cleanup()
 
 
 if __name__ == "__main__":
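
For illustration (toy values, not part of the diff), this is how the per-step token tensors collected in res stack into one [batch_size, num_generated] tensor before the final decode:

import torch

# Three decode steps for a batch of two prompts; each step yields a [2] tensor of ids
res = [torch.tensor([11, 21]), torch.tensor([12, 22]), torch.tensor([13, 23])]

stacked = torch.stack(res, dim=1)  # shape [2, 3]: one row of generated ids per prompt
print(stacked.tolist())            # [[11, 12, 13], [21, 22, 23]] -- ready for tokenizer.decode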
@@ -538,6 +520,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
     )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    parser.add_argument(
+        "--ntokens",
+        type=int,
+        default=40,
+        help="Number of tokens to generate",
+    )
+    parser.add_argument(
+        "--disable-in-flight-decode",
+        action="store_true",
+        default=False,
+        help="Whether to decode token into string in flight",
+    )
     args = parser.parse_args()
 
     main(args)
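
As a usage sketch, the new flags slot into the existing command line; the launcher, world size, and model name below are assumptions, not taken from the commit:

# Hypothetical invocation: generate 100 tokens and skip per-step string decoding
torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2 --ntokens 100 --disable-in-flight-decode

With --disable-in-flight-decode set, decode_in_flight is never called and only the final stacked response is decoded and printed by the last pipeline stage.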

0 commit comments
