This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 511b5b8
enable batch decoding, optimize dst/src creation outside of decoding loop
1 parent 7708646

File tree: 2 files changed, +48 -28 lines

dist_run.py
Lines changed: 47 additions & 27 deletions
@@ -219,10 +219,11 @@ def _update_padded_sequence(
     new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
-    # TODO: this is a hacky way to update the padded sequence: when there is
-    # more than one prompt, the for loop and the assignment is incompatible.
+
     for i in range(len(prompt_lengths)):
-        padded_sequence[i, prompt_lengths[i]] = new_token
+        prompt_lengths[i] += 1
+        padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]
+        logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")


 def _cleanup():
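
For context on this hunk: the reworked _update_padded_sequence appends the freshly decoded token to each prompt's row and advances that prompt's length. Below is a minimal, self-contained sketch of that behavior; shapes and token values are made up for illustration, and the real function also logs each update.

import torch
from typing import List

def update_padded_sequence(
    padded_sequence: torch.Tensor,  # (batch, seqlen) token ids
    new_token: torch.Tensor,        # (batch, 1) newly decoded token ids
    prompt_lengths: List[int],      # per-prompt lengths, mutated in place
) -> None:
    for i in range(len(prompt_lengths)):
        prompt_lengths[i] += 1
        # write the new token into the slot that was just appended
        padded_sequence[i, prompt_lengths[i] - 1] = new_token[i, 0]

# toy batch of two prompts padded to seqlen 8
padded = torch.zeros(2, 8, dtype=torch.int64)
lengths = [3, 5]
tokens = torch.tensor([[11], [22]], dtype=torch.int64)
update_padded_sequence(padded, tokens, lengths)
print(lengths)  # [4, 6]
print(padded[0, 3].item(), padded[1, 5].item())  # 11 22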
@@ -242,7 +243,7 @@ def main(args):
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
     logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")

-    config = ModelArgs.from_name(distribution).transformer_args['text']
+    config = ModelArgs.from_name(distribution).transformer_args["text"]
     logger.info(f"Chat Model Config: {config}")

     tokenizer = _build_chat_tokenizer(model_name)
@@ -295,7 +296,7 @@ def main(args):
     logger.info(f"Model: {model}")

     mbs = 1  # number of micro-batches
-    mb_size = 1  # micro-batch size
+    mb_size = 2  # micro-batch size
     batch_size = mbs * mb_size  # total batch size

     seqlen = 4096  # sequence length
@@ -343,6 +344,7 @@ def main(args):

     prompt = [
         "What is snow?",
+        "Where does Santa Claus live?",
     ]

     """
@@ -366,29 +368,36 @@ def main(args):

     start_pos = 0

+    # pipeline comms setup
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
+    send_destination = dist.get_global_rank(pp_group, first_pp_rank)
+    recv_source = dist.get_global_rank(pp_group, last_pp_rank)
+
     # encode the prompt
     input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
-    logger.info(f"{input_ids[0:8]=}")
+    logger.info(f"{input_ids[0][0:8]=}")

     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"{prompt_lengths=}")
+    logger.info(f"length of each prompt in the batch: {prompt_lengths=}")

     # create schedule
     schedule = ScheduleGPipe(stage, mbs)

     # with CUDATrackTime() as timer:
-    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1

     # New token generated each iteration
-    new_token = torch.zeros(1, device=device, dtype=torch.int64)
-    res = []
-    num_tokens = 40
+    total_prompts = len(prompt_lengths)
+    # need a new token dimension (row) for each prompt in the batch
+    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
+    res = [[] for _ in range(total_prompts)]
+    num_tokens = 20

     # Decoding
     with torch.no_grad():
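
To illustrate the "dst/src creation outside of decoding loop" half of this commit, here is a hedged, single-process sketch (gloo backend, loopback address and port chosen arbitrarily) of resolving the pipeline group's first and last ranks to global ranks once, before the token loop, together with the batched new_token/res allocation. Only the names pp_group, send_destination, recv_source, new_token, and res mirror the diff; the surrounding setup is illustrative, not the repo's launch code.

import os
import torch
import torch.distributed as dist

# stand-in for the real multi-rank launch; a real run has pp_group_size > 1
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
pp_group = dist.new_group(ranks=[0])
pp_group_size = dist.get_world_size(pp_group)

# resolved once, outside the per-token decoding loop
first_pp_rank = 0
last_pp_rank = pp_group_size - 1
send_destination = dist.get_global_rank(pp_group, first_pp_rank)
recv_source = dist.get_global_rank(pp_group, last_pp_rank)

# one new-token row and one result list per prompt in the batch
prompt_lengths = [3, 5]  # illustrative
total_prompts = len(prompt_lengths)
new_token = torch.zeros(total_prompts, 1, dtype=torch.int64)
res = [[] for _ in range(total_prompts)]

print(send_destination, recv_source, new_token.shape, len(res))
dist.destroy_process_group()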
@@ -412,39 +421,45 @@ def main(args):
                     f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                 )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-                new_token = torch.tensor([decode_results[0][0]], device=device)
-                res.append(decode_results[0][1])
+                for i in range(len(decode_results)):
+                    res[i].append(decode_results[i][1])
+                    new_token[i, 0] = torch.tensor(
+                        [decode_results[i][0]], device=device
+                    )  # decode_results[i][0]
+
+                # increment prompt lengths for next token
+                for i in range(len(prompt_lengths)):
+                    prompt_lengths[i] += 1

             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
                 dist.send(
                     new_token,
-                    dst=dist.get_global_rank(pp_group, first_pp_rank),
+                    dst=send_destination,
                     group=pp_group,
                 )
             elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
                 dist.recv(
                     new_token,
-                    src=dist.get_global_rank(pp_group, last_pp_rank),
+                    src=recv_source,
                     group=pp_group,
                 )

             # Update input sequence with new token
             if pp_rank == first_pp_rank:
-                _update_padded_sequence(
-                    padded_sequence, new_token, prompt_lengths
-                )
-
-                # increment prompt lengths for next token
-                for i in range(len(prompt_lengths)):
-                    prompt_lengths[i] += 1
+                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)
+                for i in range(len(prompt_lengths)):
+                    logger.info(
+                        f"next submission: {padded_sequence[i, prompt_lengths[i]-4:prompt_lengths[i]+4]}"
+                    )

     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = " ".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
+        for i in range(len(prompt_lengths)):
+            logger.info(f"Prompt:{color.green} {prompt[i]} {color.reset}")
+            formatted_response = "".join(res[i])
+            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$")

     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
@@ -454,7 +469,12 @@ def main(args):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument(
+        "model_name",
+        type=str,
+        help="Name of the model to load",
+        choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
+    )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     args = parser.parse_args()

run_dist.sh
Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ NGPU=${NGPU:-"4"}
 LOG_RANK=${LOG_RANK:-0,1,2,3}
 torchrun --nproc-per-node=$NGPU --master_port=$PORT \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-dist_run.py
+dist_run.py --pp 2 llama3
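
For reference, run_dist.sh is driven by the PORT, NGPU, and LOG_RANK variables visible above; one plausible invocation on a 4-GPU host (the specific values here are assumptions, not part of the commit) would be:

PORT=29500 NGPU=4 LOG_RANK=0,1,2,3 bash run_dist.sh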
