63 changes: 44 additions & 19 deletions dist_run.py
@@ -31,7 +31,6 @@
get_num_params,
GPUMemoryMonitor,
)
from distributed.verification_utils import find_cpu_tensors
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
from torchchat.model import ModelArgs, Transformer
@@ -219,10 +218,9 @@ def _update_padded_sequence(
new_token: torch.Tensor,
prompt_lengths: List[int],
) -> None:
# TODO: this is a hacky way to update the padded sequence: when there is
# more than one prompt, the for loop and the assignment is incompatible.
for i in range(len(prompt_lengths)):
padded_sequence[i, prompt_lengths[i]] = new_token
padded_sequence[i, prompt_lengths[i]] = new_token[i, 0]
# logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
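The fix above turns new_token into a per-prompt column so each row of the padded batch receives its own token rather than a shared scalar. A minimal sketch of the same pattern, with invented shapes and values (not taken from the diff):

import torch

# Toy example: three prompts of lengths 2, 4 and 3 in a length-8 padded buffer.
padded = torch.zeros(3, 8, dtype=torch.int64)
lengths = [2, 4, 3]
new_tok = torch.tensor([[11], [22], [33]])  # shape (batch, 1): one token per prompt

for i in range(len(lengths)):
    padded[i, lengths[i]] = new_tok[i, 0]  # row i receives its own token

# afterwards: padded[0, 2] == 11, padded[1, 4] == 22, padded[2, 3] == 33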


def _cleanup():
@@ -242,7 +240,7 @@ def main(args):
distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")

config = ModelArgs.from_name(distribution).transformer_args['text']
config = ModelArgs.from_name(distribution).transformer_args["text"]
logger.info(f"Chat Model Config: {config}")

tokenizer = _build_chat_tokenizer(model_name)
@@ -295,7 +293,7 @@ def main(args):
logger.info(f"Model: {model}")

mbs = 1 # number of micro-batches
mb_size = 1 # micro-batch size
mb_size = 5 # micro-batch size
batch_size = mbs * mb_size # total batch size

seqlen = 4096 # sequence length
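Note that the micro-batch size is raised from 1 to 5 so the single GPipe micro-batch covers the five prompts added below; the total batch size follows directly (values from the diff, the check itself is illustrative):

mbs = 1        # number of micro-batches per schedule step
mb_size = 5    # prompts per micro-batch
batch_size = mbs * mb_size
assert batch_size == 5  # must equal len(prompt) for the padded batch to line up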
@@ -343,6 +341,10 @@

prompt = [
"What is snow?",
"Where does Santa Claus live?",
"What is PyTorch?",
"Write a poem about the beauty of the night sky.",
"What is the capital of France, Germany and Switzerland?",
]

"""
@@ -366,17 +368,23 @@

start_pos = 0

# pipeline comms setup
first_pp_rank = 0
last_pp_rank = pp_group_size - 1

# Need these global ids due to the API definition of dist.send and recv
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)

# encode the prompt
input_ids = _encode_strings(
prompt, tokenizer, bos=True, device=device, dtype=torch.int64
)
logger.info(f"{input_ids[0:8]=}")

# create a padded tensor for the input prompt
padded_sequence, prompt_lengths = _create_padded_prompts(
input_ids, tokenizer, seqlen, start_pos, device
)
logger.info(f"{prompt_lengths=}")

# create schedule
schedule = ScheduleGPipe(stage, mbs)
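For orientation, ScheduleGPipe(stage, mbs) chunks its input into mbs micro-batches and drives this rank's pipeline stage. A hedged sketch of how such a schedule is typically stepped per rank with the torch.distributed.pipelining API; this is an assumption about the surrounding loop, not lines from this diff:

if pp_rank == first_pp_rank:
    schedule.step(padded_sequence)   # first stage feeds the padded batch in
elif pp_rank == last_pp_rank:
    output = schedule.step()         # last stage returns the model output
else:
    schedule.step()                  # middle stages only relay activations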
@@ -389,8 +397,10 @@
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)

# New token generated each iteration
new_token = torch.zeros(1, device=device, dtype=torch.int64)
res = []
total_prompts = len(prompt_lengths)
# need a new token dimension (row) for each prompt in the batch
new_token = torch.zeros(total_prompts, 1, device=device, dtype=torch.int64)
res = [[] for _ in range(total_prompts)]
num_tokens = 40
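One detail worth flagging in the new buffers: res must hold one independent list per prompt, which is why a comprehension is used rather than list multiplication. A toy illustration of the aliasing pitfall being avoided (values hypothetical):

total_prompts = 5
good = [[] for _ in range(total_prompts)]  # five distinct lists
bad = [[]] * total_prompts                 # five references to ONE list

bad[0].append("token")
assert bad[1] == ["token"]  # aliasing: every row sees the append
assert good[1] == []        # rows stay independent with the comprehension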

# Decoding
@@ -415,8 +425,11 @@
f"responses ====>>>> {color.blue} {decode_results=}{color.reset}"
)
# decode results returns both token_id (int) and token_str (readable), hence [0] and [1]
new_token = torch.tensor([decode_results[0][0]], device=device)
res.append(decode_results[0][1])
for i in range(len(decode_results)):
res[i].append(decode_results[i][1])
new_token[i, 0] = torch.tensor(
[decode_results[i][0]], device=device
) # decode_results[i][0]

# sendrecv between last and first ranks, only if:
# first_pp_rank != last_pp_rank.
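The send/recv itself sits in a collapsed portion of the diff; a hedged sketch of what such a hand-off typically looks like, reusing the global ids computed earlier (an assumption, not the diff's exact code):

if first_pp_rank != last_pp_rank:
    if pp_rank == last_pp_rank:
        # the last stage decoded new_token; ship it back to the first stage
        dist.send(new_token, dst=first_pp_rank_global_id, group=pp_group)
    elif pp_rank == first_pp_rank:
        # the first stage needs new_token to extend padded_sequence
        dist.recv(new_token, src=last_pp_rank_global_id, group=pp_group)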
@@ -435,20 +448,27 @@

# Update input sequence with new token
if pp_rank == first_pp_rank:
_update_padded_sequence(
padded_sequence, new_token, prompt_lengths
)
_update_padded_sequence(padded_sequence, new_token, prompt_lengths)

# increment prompt lengths for next token
for i in range(len(prompt_lengths)):
prompt_lengths[i] += 1

# Display the decoding results

# output formatted response via last pp group and tp rank 0
if pp_rank == last_pp_rank and tp_rank == 0:
logger.info(f"Prompt:{color.green} {prompt[0]} {color.reset}")
formatted_response = " ".join(res)
logger.info(f"$$$$$$ {color.blue}{formatted_response} {color.reset} $$$$$")
for i in range(len(prompt_lengths)):
logger.info(f"\nPrompt:{color.green} {prompt[i]} {color.reset}")

# TODO: resolve issue with llama2-7b-chat model and "".join
if model_name != "llama2-7b-chat":
formatted_response = "".join(res[i])
else:
formatted_response = " ".join(res[i])
logger.info(f"$$ {color.red}{formatted_response} {color.reset} $$\n")

# Cleanup
logger.info(
f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
)
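On the join split above: tokenizer families differ in whether decoded pieces carry their own leading spaces, which is presumably why llama2-7b-chat still wants a space separator. A toy illustration with invented strings:

pieces = ["Snow", " is", " frozen", " water."]  # pieces that embed their spacing
print("".join(pieces))    # -> Snow is frozen water.

words = ["Snow", "is", "frozen", "water."]      # pieces without embedded spacing
print(" ".join(words))    # -> Snow is frozen water.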
@@ -457,7 +477,12 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
parser.add_argument(
"model_name",
type=str,
help="Name of the model to load",
choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
)
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
args = parser.parse_args()

2 changes: 1 addition & 1 deletion run_dist.sh
@@ -4,4 +4,4 @@ NGPU=${NGPU:-"4"}
LOG_RANK=${LOG_RANK:-0,1,2,3}
torchrun --nproc-per-node=$NGPU --master_port=$PORT \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
dist_run.py
dist_run.py --pp 2 llama3
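With the pipeline degree and model now baked into the launch line, a typical invocation only needs the environment overrides (PORT is assumed to be defined earlier in the script, as the torchrun flags suggest):

PORT=29500 NGPU=4 ./run_dist.sh   # launches 4 ranks with a 2-stage pipeline on llama3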