
Commit 51dc929

Make in-flight decoding optional

1 parent c9e6152

File tree

1 file changed: +53 −59 lines changed


dist_run.py

Lines changed: 53 additions & 59 deletions
@@ -187,30 +187,23 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    tokenizer,
-    prompt_lengths: Optional[List[int]] = None,
-) -> List[Tuple[int, str]]:
+    pos: int,
+) -> torch.Tensor:
     """
     Decode the next token for each prompt in the batch.
+    Args:
+        output (torch.Tensor): The output tensor to decode.
+        pos: the position of the `output` to decode in the sequence length dimension.
 
     Returns:
-        List[Tuple[int, str]]: List of tuples containing the next token id and its
-        decoded string for each prompt in the batch.
+        Decoded token ids.
     """
-    batch_size = output.shape[0]
-    results = []
-
-    for i in range(batch_size):
-        pos = prompt_lengths[i] - 1 if prompt_lengths is not None else 0
-        next_token_logits = output[i, pos, :]
-
-        # Argmax (deterministic) TODO: add temperature
-        next_token = torch.argmax(next_token_logits, dim=-1)
-
-        next_token_decoded = tokenizer.decode([next_token.item()])
-        results.append((next_token.item(), next_token_decoded))
-
-    return results
+    # Take the next token logits for each prompt
+    next_token_logits = output[:, pos, :]
+    # Argmax (deterministic) TODO: add temperature
+    next_token = torch.argmax(next_token_logits, dim=-1)
+    # Token ids in int tensor form
+    return next_token
 
 
 def _update_padded_sequence(
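Note on the refactor above: the helper now takes a single sequence position and greedily argmaxes over the whole batch at once, instead of looping per prompt and decoding to strings inside the function. A minimal standalone sketch of that behavior, using made-up toy shapes (not code from this commit):

import torch

# Toy shapes for illustration only
batch_size, seqlen, vocab_size = 2, 8, 32
output = torch.randn(batch_size, seqlen, vocab_size)  # stand-in for the model's logits output

pos = 5  # e.g. prompt_lengths[0] - 1 during prefill, 0 during decode
next_token_logits = output[:, pos, :]                 # (batch_size, vocab_size)
next_token = torch.argmax(next_token_logits, dim=-1)  # (batch_size,) int64 token ids
print(next_token.shape, next_token.dtype)             # torch.Size([2]) torch.int64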
@@ -401,8 +394,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # New token generated each iteration
     # need a row dimension for each prompt in the batch
     new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
-    res = [[] for _ in range(batch_size)]
-    num_tokens = 40
+    # Store the generated tokens
+    res = []
 
     # Prefill phase
     # Run context input through pipeline
@@ -421,23 +414,24 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
 
-    # Decode the output -- first generated token
-    if pp_rank == last_pp_rank:
-        decode_results = _batch_decode_next_tokens(
-            output=output,
-            tokenizer=tokenizer,
-            prompt_lengths=prompt_lengths,
-        )
-        for i in range(len(decode_results)):
-            new_token[i, 0] = torch.tensor(
-                [decode_results[i][0]], device=device
-            )  # token_id in int form
+    # Decode token id into string and print it
+    def decode_in_flight(token):
+        # Make a 2D tensor with ids on row dimension
+        unsqueezed = torch.unsqueeze(token, 1)
+        token_str = tokenizer.decode(unsqueezed.tolist())
         if tp_rank == 0:
             logger.info(
-                f"{color.green} {'* Prefill *'} "
-                f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
+                f"{color.green} responses ====>>>> "
+                f"{color.blue} {token_str} {color.reset}"
             )
 
+    # Decode the output -- first generated token
+    if pp_rank == last_pp_rank:
+        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
+        res.append(new_token)
+        if not args.disable_in_flight_decode:
+            decode_in_flight(new_token)
+
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
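Shape handling in the new decode_in_flight helper, sketched with a toy stand-in tokenizer (the real tokenizer's batch-decode interface is assumed, not shown in this diff): the (batch_size,) tensor of new token ids is unsqueezed to (batch_size, 1) so that .tolist() yields one id list per prompt.

import torch

class ToyTokenizer:
    # Stand-in: decodes each list of token ids to a placeholder string
    def decode(self, batch_of_id_lists):
        return ["".join(f"<{i}>" for i in ids) for ids in batch_of_id_lists]

tokenizer = ToyTokenizer()
token = torch.tensor([5, 9])                  # one new token id per prompt in the batch
unsqueezed = torch.unsqueeze(token, 1)        # shape (batch_size, 1)
print(tokenizer.decode(unsqueezed.tolist()))  # ['<5>', '<9>']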
@@ -459,7 +453,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
-        for step in range(num_tokens - 1):
+        for step in range(args.ntokens - 1):
             kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
@@ -486,21 +480,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Decode the output
             if pp_rank == last_pp_rank:
-                decode_results = _batch_decode_next_tokens(
-                    output=output, tokenizer=tokenizer
-                )
-                if tp_rank == 0:
-                    logger.info(
-                        f"{color.green} {'* Decode *'} "
-                        f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
-                    )
-                # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-                for i in range(len(decode_results)):
-                    res[i].append(decode_results[i][1])
-                    new_token[i, 0] = torch.tensor(
-                        [decode_results[i][0]], device=device
-                    )  # decode_results[i][0]
+                new_token = _batch_decode_next_tokens(output, 0)
+                res.append(new_token)
+                if not args.disable_in_flight_decode:
+                    decode_in_flight(new_token)
 
+            # Increment input position
             input_pos += 1
 
     logger.info(
@@ -511,21 +496,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        for i in range(len(prompt_lengths)):
-            logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")
-
-            # TODO: resolve issue with llama2-7b-chat model and "".join
-            if model_name != "llama2-7b-chat":
-                formatted_response = "".join(res[i])
-            else:
-                formatted_response = " ".join(res[i])
-            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
+        # `res` is a list of tensors, each being a batch of generated token ids
+        res = torch.stack(res, dim=1)
+        res_list = res.tolist()
+        response = tokenizer.decode(res_list)
+        for i in range(len(response)):
+            logger.info(f"$$ {color.red}{response[i]} {color.reset} $$\n")
 
     # Cleanup
+    _cleanup()
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
-    _cleanup()
 
 
 if __name__ == "__main__":
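Response assembly after this change, illustrated with made-up token ids: each decode step appends a (batch_size,) tensor to res, so stacking along dim=1 gives one row of generated ids per prompt, which can then be batch-decoded.

import torch

# Each decode step appends a (batch_size,) tensor of token ids to `res`
res = [torch.tensor([11, 21]), torch.tensor([12, 22]), torch.tensor([13, 23])]

stacked = torch.stack(res, dim=1)  # shape (batch_size, num_generated_tokens) == (2, 3)
print(stacked.tolist())            # [[11, 12, 13], [21, 22, 23]] -- one id list per prompt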
@@ -537,6 +519,18 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
     )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    parser.add_argument(
+        "--ntokens",
+        type=int,
+        default=40,
+        help="Number of tokens to generate",
+    )
+    parser.add_argument(
+        "--disable-in-flight-decode",
+        action="store_true",
+        default=False,
+        help="Whether to decode token into string in flight",
+    )
     args = parser.parse_args()
 
     main(args)
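With the new arguments, generation length and in-flight decoding are controlled from the command line. A hypothetical invocation (the torchrun launch and positional model argument are assumptions, not shown in this diff): torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --pp 2 --ntokens 64 --disable-in-flight-decode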
