
Commit 2e3d1dc

Merge branch 'main' into unify-constuct-model

2 parents 7ec018a + 94e56f1

File tree

2 files changed: +50, -22 lines

dist_run.py

Lines changed: 49 additions & 21 deletions
@@ -31,7 +31,6 @@
     get_num_params,
     GPUMemoryMonitor,
 )
-from distributed.verification_utils import find_cpu_tensors
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer
@@ -219,10 +218,9 @@ def _update_padded_sequence(
     new_token: torch.Tensor,
     prompt_lengths: List[int],
 ) -> None:
-    # TODO: this is a hacky way to update the padded sequence: when there is
-    # more than one prompt, the for loop and the assignment is incompatible.
     for i in range(len(prompt_lengths)):
-        padded_sequence[i, prompt_lengths[i]] = new_token
+        padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]
+        # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")


 def _cleanup():
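
Aside (not part of the commit): `_update_padded_sequence` now writes one freshly generated token per prompt, which is why `new_token` gains a row per prompt. A minimal, self-contained sketch with made-up sizes and values:

import torch

# hypothetical batch of 3 padded prompts with seqlen 16
padded_sequence = torch.zeros(3, 16, dtype=torch.int64)
prompt_lengths = [4, 7, 5]                    # current length of each prompt
new_token = torch.tensor([[11], [22], [33]])  # shape (num_prompts, 1)

for i in range(len(prompt_lengths)):
    # place prompt i's next token right after its last real token
    padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]

print(padded_sequence[:, :8])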
@@ -242,7 +240,7 @@ def main(args):
     distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
     logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")

-    config = ModelArgs.from_name(distribution).transformer_args['text']
+    config = ModelArgs.from_name(distribution).transformer_args["text"]
     logger.info(f"Chat Model Config: {config}")

     tokenizer = _build_chat_tokenizer(model_name)
@@ -295,7 +293,7 @@ def main(args):
     logger.info(f"Model: {model}")

     mbs = 1 # number of micro-batches
-    mb_size = 1 # micro-batch size
+    mb_size = 5 # micro-batch size
     batch_size = mbs * mb_size # total batch size

     seqlen = 4096 # sequence length
@@ -343,6 +341,10 @@ def main(args):

     prompt = [
         "What is snow?",
+        "Where does Santa Claus live?",
+        "What is PyTorch?",
+        "Write a poem about the beauty of the night sky.",
+        "What is the capital of France, Germany and Switzerland?",
     ]

     """
@@ -366,28 +368,39 @@ def main(args):

     start_pos = 0

+    # pipeline comms setup
+    first_pp_rank = 0
+    last_pp_rank = pp_group_size - 1
+
+    # Need these global ids due to the API definition of dist.send and recv
+    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
+    last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
+
     # encode the prompt
     input_ids = _encode_strings(
         prompt, tokenizer, bos=True, device=device, dtype=torch.int64
     )
-    logger.info(f"{input_ids[0:8]=}")

     # create a padded tensor for the input prompt
     padded_sequence, prompt_lengths = _create_padded_prompts(
         input_ids, tokenizer, seqlen, start_pos, device
     )
-    logger.info(f"{prompt_lengths=}")

     # create schedule
     schedule = ScheduleGPipe(stage, mbs)

     # with CUDATrackTime() as timer:
     first_pp_rank = 0
     last_pp_rank = pp_group_size - 1
+    # Need these global ids due to the API definition of dist.send and recv
+    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
+    last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)

     # New token generated each iteration
-    new_token = torch.zeros(1, device=device, dtype=torch.int64)
-    res = []
+    total_prompts = len(prompt_lengths)
+    # need a new token dimension (row) for each prompt in the batch
+    new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
+    res = [[] for _ in range(total_prompts)]
     num_tokens = 40

     # Decoding
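
Aside (context only, not part of the diff): `dist.send` and `dist.recv` address peers by their global rank in the default process group, while pipeline stages are indexed by their rank inside `pp_group`, hence the one-time translation through `dist.get_global_rank`. A toy, single-process sketch of that mapping (hypothetical values, gloo backend; the real script runs under torchrun with multiple ranks):

import os
import torch.distributed as dist

# single-rank setup so this sketch runs anywhere
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

pp_group = dist.new_group(ranks=[0])  # stand-in for the script's pipeline-parallel group
first_pp_rank = 0
last_pp_rank = dist.get_world_size(pp_group) - 1

# translate group-local ranks into the global ranks that dist.send/dist.recv expect
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
print(first_pp_rank_global_id, last_pp_rank_global_id)  # both 0 in this one-rank toy

dist.destroy_process_group()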
@@ -412,40 +425,50 @@ def main(args):
                         f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
                     )
                 # decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
-                new_token = torch.tensor([decode_results[0][0]], device=device)
-                res.append(decode_results[0][1])
+                for i in range(len(decode_results)):
+                    res[i].append(decode_results[i][1])
+                    new_token[i, 0] = torch.tensor(
+                        [decode_results[i][0]], device=device
+                    ) # decode_results[i][0]

             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
                 dist.send(
                     new_token,
-                    dst=dist.get_global_rank(pp_group, first_pp_rank),
+                    dst=first_pp_rank_global_id,
                     group=pp_group,
                 )
             elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
                 dist.recv(
                     new_token,
-                    src=dist.get_global_rank(pp_group, last_pp_rank),
+                    src=last_pp_rank_global_id,
                     group=pp_group,
                 )

             # Update input sequence with new token
             if pp_rank == first_pp_rank:
-                _update_padded_sequence(
-                    padded_sequence, new_token, prompt_lengths
-                )
+                _update_padded_sequence(padded_sequence, new_token, prompt_lengths)

             # increment prompt lengths for next token
             for i in range(len(prompt_lengths)):
                 prompt_lengths[i] += 1

+    # Display the decoding results
+
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
-        formatted_response = " ".join(res)
-        logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
+        for i in range(len(prompt_lengths)):
+            logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")

+            # TODO: resolve issue with llama2-7b-chat model and "".join
+            if model_name != "llama2-7b-chat":
+                formatted_response = "".join(res[i])
+            else:
+                formatted_response = " ".join(res[i])
+            logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")
+
+    # Cleanup
     logger.info(
         f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
     )
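
Aside (purely illustrative, not from the commit): responses are now accumulated per prompt, and the TODO in the hunk above suggests the join separator depends on whether the tokenizer's decoded pieces already carry their own leading spaces. A tiny sketch with hypothetical token strings:

# hypothetical per-prompt token strings
res = [
    ["Snow", " is", " frozen", " precipitation", "."],
    ["Santa", " Claus", " lives", " at", " the", " North", " Pole", "."],
]
for tokens in res:
    print("".join(tokens))   # natural when pieces include their own spacing
    print(" ".join(tokens))  # fallback this commit keeps for llama2-7b-chat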
@@ -454,7 +477,12 @@ def main(args):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument(
+        "model_name",
+        type=str,
+        help="Name of the model to load",
+        choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
+    )
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     args = parser.parse_args()

run_dist.sh

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ NGPU=${NGPU:-"4"}
 LOG_RANK=${LOG_RANK:-0,1,2,3}
 torchrun --nproc-per-node=$NGPU --master_port=$PORT \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-dist_run.py
+dist_run.py --pp 2 llama3
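
Aside (hedged launch example, not part of the commit): the NGPU and LOG_RANK defaults are visible in the hunk above; PORT is assumed to be defined earlier in run_dist.sh with the same ${VAR:-default} pattern, so the override below is illustrative only.

# hypothetical invocation on a 4-GPU host
NGPU=4 LOG_RANK=0,3 PORT=29500 bash run_dist.sh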
