238 changes: 195 additions & 43 deletions dist_run.py
@@ -7,34 +7,23 @@
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
from typing import Any, Dict, List, Optional, Tuple

# Run command:
# torchrun --nproc-per-node 4 dist_run.py
import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe


from distributed.logging_utils import SingletonLogger

# TODO - these are not distributed specific, consider moving to new package
from distributed.safetensor_utils import (
get_hf_config_file,
get_hf_weight_map_and_path,
load_safetensor_weights,
)

from distributed.utils import (
Color as color,
GPUMemoryMonitor,
get_module_size,
get_num_params,
bytes_to_readable,
TrackTime,
CUDATrackTime,
)

from distributed.safetensor_utils import (get_hf_config_file,
get_hf_weight_map_and_path,
load_safetensor_weights)
from distributed.utils import Color as color
from distributed.utils import (GPUMemoryMonitor, TrackTime, CUDATrackTime,
bytes_to_readable, get_module_size,
get_num_params)
from distributed.verification_utils import find_cpu_tensors
from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
from torchchat.model import ModelArgs, Transformer
@@ -52,7 +41,8 @@

logger = SingletonLogger.get_logger()

MODEL_NAME = "Transformer-2-7b-chat-hf"
MODEL_NAME = "Meta-Llama-3-8B"

NAME_TO_HF_MODEL_ID_AND_DTYPE = {
"Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
"Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
@@ -122,6 +112,88 @@ def _load_model_weights(stage_module, hf_model_name, device, model_config):
raise ValueError(f"Missing {num_missing_weights} weights")


def _encode_strings(
strings: List[str],
tokenizer,
bos: bool = True,
device: str = "cuda",
dtype=torch.int64,
) -> List[torch.Tensor]:
"""Encode a list of prompt strings into a list of tensor token ids."""
encoded_list = []
for string in strings:
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device))
return encoded_list


def _create_padded_prompts(
input_ids_list: List[torch.Tensor],
tokenizer,
seqlen: int,
start_pos: int,
device: str,
pad_token_id: Optional[int] = None,
) -> Tuple[torch.Tensor, List[int]]:
"""
Create a padded tensor for multiple encoded input prompts.

Returns:
Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths.
"""
pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id()

# Find the maximum prompt length
max_prompt_len = max(ids.size(0) for ids in input_ids_list)

# Calculate the buffer size
max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len))
token_buffer_size = max_prompt_len + max_new_tokens

# Create the padded batch tensor
batch_size = len(input_ids_list)
batch_seq = torch.full(
(batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device
)

prompt_lengths = []
for i, input_ids in enumerate(input_ids_list):
prompt_len = input_ids.size(0)
batch_seq[i, :prompt_len] = input_ids
prompt_lengths.append(prompt_len)

return batch_seq, prompt_lengths


def _batch_decode_next_tokens(
output: torch.Tensor,
prompt_lengths: List[int],
tokenizer,
) -> List[Tuple[int, str]]:
"""
Decode the next token for each prompt in the batch.

Returns:
List[Tuple[int, str]]: List of tuples containing the next token id and its
decoded string for each prompt in the batch.
"""
batch_size = output.shape[0]
results = []

for i in range(batch_size):
next_token_logits = output[i, prompt_lengths[i] - 1, :]

# Argmax (deterministic) TODO: add temperature
next_token = torch.argmax(next_token_logits, dim=-1)

next_token_decoded = tokenizer.decode([next_token.item()])
results.append((next_token.item(), next_token_decoded))

return results
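Reviewer note: the TODO above mentions adding temperature to the argmax decode. A minimal sketch of what that could look like (a hypothetical `_sample_next_token` helper, not part of this PR; it assumes `next_token_logits` is the 1-D `[vocab_size]` slice taken in the loop above):

def _sample_next_token(
    next_token_logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """Sample a next-token id from logits; temperature == 0 falls back to argmax."""
    if temperature == 0.0:
        return torch.argmax(next_token_logits, dim=-1)
    logits = next_token_logits / temperature
    if top_k is not None:
        # Mask everything outside the top-k logits to -inf before softmax
        topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(
            logits < topk_vals[..., -1], torch.full_like(logits, float("-inf")), logits
        )
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)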


def _cleanup():
dist.barrier()
dist.destroy_process_group()
@@ -133,8 +205,8 @@ def main():
gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']
logger.info(f"Chat Model Config: {config}")
config = ModelArgs.from_name(MODEL_NAME).transformer_args["text"]
logger.info(f"Chat Model Name: {MODEL_NAME}\nModel Config: {config}")

tokenizer = _build_chat_tokenizer()
logger.info(f"built tokenizer {tokenizer=}")
@@ -182,11 +254,12 @@ def main():

# Distribute model on TP mesh
model.distribute(tp_mesh)
logger.info(f"Model: {model}")
# logger.info(f"Model: {model}")

mbs = 2 # number of micro-batches
mbs = 1 # number of micro-batches
mb_size = 1 # micro-batch size
batch_size = mbs * mb_size # total batch size

seqlen = 4096 # sequence length
dim = 4096 # embedding dimension
assert seqlen % sp_degree == 0
@@ -199,7 +272,7 @@

# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
with TrackTime("cuda") as timer:
with TrackTime() as timer:
_load_model_weights(model, hf_model_name, device=device, model_config=config)
logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
@@ -212,9 +285,8 @@
logger.info(
f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
)

# Setup input position
# input_pos for prefill: a list of increasing integers from 0 to seqlen
input_pos = torch.arange(seqlen, device=device)
model.setup_input_pos(input_pos)
model.eval()
@@ -235,30 +307,109 @@
if len(cpu_tensors) > 0:
raise ValueError("Found cpu tensors in stage")

# TODO: this can likely be removed after we prove out a few more models
# verify dtypes for model - expect all to be model_dtype except for bool causal_mask atm.
# dtype_count, dtype_locations, fp32_locations = record_module_dtypes(stage.submod)
# logger.info(
# f"Stage Dtypes - Found {len(dtype_count)} dtypes: {dtype_count.items()}"
# )
# assert (
# len(dtype_count) == 2
# ), f"Expected 2 dtypes in model after checkpoint loading: {model_dtype} and {torch.bool}"
prompt = [
"What is snow?",
]

'''
"What is the capital of France?",
"What is your name?",
"What is the capital of Japan?",
"When is Christmas?",
"Where does Santa Claus live?",
"What is the capital of the United States?",
"What is the capital of China?",
"What is the capital of Russia?",
"What is PyTorch?",
"What is the capital of India?",
"What is an LLM?",
"What is the capital of Brazil?",
"What is the capital of Mexico?",
"What is the capital of Argentina?",
"What is the capital of Canada?",
]
'''

start_pos = 0

# encode the prompt
input_ids = _encode_strings(
prompt, tokenizer, bos=True, device=device, dtype=torch.int64
)
logger.info(f"{input_ids[0:8]=}")

input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
logger.info(f"Input: {input_ids.dtype=}, {input_ids.shape=}, {input_ids.device=}")
# create a padded tensor for the input prompt
padded_sequence, prompt_lengths = _create_padded_prompts(
input_ids, tokenizer, seqlen, start_pos, device
)
Contributor comment:
nit: (1) why padding the input_ids requires the tokenizer again? (2) maybe device can be inferred from input_ids?
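One way this nit could be addressed (a sketch, not part of this PR: it assumes the caller passes `pad_token_id` explicitly, e.g. `tokenizer.eos_id()`, and that `input_ids_list` is non-empty so the device can be inferred from its first tensor):

def _create_padded_prompts(
    input_ids_list: List[torch.Tensor],
    seqlen: int,
    start_pos: int,
    pad_token_id: int,
) -> Tuple[torch.Tensor, List[int]]:
    """Pad encoded prompts into one batch tensor; device is taken from the inputs."""
    device = input_ids_list[0].device
    max_prompt_len = max(ids.size(0) for ids in input_ids_list)
    # Same buffer-size logic as the original helper
    max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len))
    token_buffer_size = max_prompt_len + max_new_tokens
    batch_seq = torch.full(
        (len(input_ids_list), token_buffer_size),
        pad_token_id,
        dtype=torch.int64,
        device=device,
    )
    prompt_lengths = []
    for i, input_ids in enumerate(input_ids_list):
        batch_seq[i, : input_ids.size(0)] = input_ids
        prompt_lengths.append(input_ids.size(0))
    return batch_seq, prompt_lengths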

logger.info(f"{prompt_lengths=}")
logger.info(f"first prompt {padded_sequence[0, :prompt_lengths[0]+1]=}")
if len(prompt_lengths) > 1:
logger.info(f"second prompt {padded_sequence[1, :prompt_lengths[1]+1]=}")
Contributor comment on lines +390 to +393:
nit: can we clean up these logs a little bit? And why do we need the +1?
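A possible cleanup along the lines of this nit (a sketch; it replaces the hard-coded first/second prompt logs with one loop over all prompts):

for i, prompt_len in enumerate(prompt_lengths):
    logger.info(f"prompt {i}: {padded_sequence[i, :prompt_len]=}")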


schedule = ScheduleGPipe(stage, mbs)
logger.info(f"Created schedule: {schedule}")

# with CUDATrackTime() as timer:
first_pp_rank = 0
last_pp_rank = pp_degree - 1

x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,], device=device)
x_recv = torch.zeros_like(x)

last_global_rank = world_size - 1

with torch.no_grad(): # .inference_mode():
# for _ in range(1):
# first
if pp_rank == 0:
schedule.step(input_ids)
else:
schedule.step(padded_sequence)
if rank == 0:
dist.recv(padded_sequence, src=last_global_rank, )
logger.info(f"RECEIVING from {last_global_rank=}")
# elif rank == 1:
# dist.recv(x_recv, src=last_global_rank-1, )
# logger.info(f"RECEIVING from {last_global_rank=}")
# logger.info(f"Received x_recv: {x_recv=}")

# elif tp_rank == 1:
# dist.recv(padded_sequence, src=last_global_rank-1, )
# last
elif pp_rank == last_pp_rank:
output = schedule.step()

logger.info(f"SENDING back...from {pp_rank=}")
#if tp_rank == 0:
if rank == world_size-1:
dist.send(output, dst=0, )
#dist.send(x, dst=1, )

# dist.send(x,dst = 1,)
#elif tp_rank==1:
# dist.send(output, dst=1, )
# elif tp_rank == 1:
# dist.send(output, dst=1, )
# middle pp ranks
else:
schedule.step()

if rank==0:
logger.info(f"{color.red} Success! Received output from {last_global_rank} {color.reset}")
logger.info(f"out of loop - Received output: {padded_sequence[4:8]=}") # {padded_sequence[0, :prompt_lengths[0]+1]=}")
if rank ==1:
logger.info(f"{color.red} Success! Received output from {last_global_rank} {color.reset}")
logger.info(f"out of loop Received output: {x_recv=}") # {padded_sequence[0, :prompt_lengths[0]+1]=}")

#logger.info(f"{color.green}Total prefill time: {timer.get_time()} {timer.unit}{color.reset}")
'''
# Decoding
if pp_rank == pp_degree - 1 and tp_rank == 0:
logger.info(f"Output: {output}")
decode_results = _batch_decode_next_tokens(
output=output, prompt_lengths=prompt_lengths, tokenizer=tokenizer
)

logger.info(
f"\n\n{color.green} Prefill responses ====>>>> {color.blue} {decode_results=} \n{color.reset}"
)

# show peak memory stats for this stage
res_mem_gib, res_mem_pct = gpu_memory_monitor.get_peak_stats()
@@ -269,7 +420,8 @@ def main():
logger.info(
f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}"
)

'''
logger.info(f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}")
_cleanup()


2 changes: 1 addition & 1 deletion run_dist.sh
@@ -1,4 +1,4 @@
export CUDA_VISIBLE_DEVICES=4,5,6,7
#export CUDA_VISIBLE_DEVICES=4,5,6,7
PORT=${1:-29501}
NGPU=${NGPU:-"4"}
LOG_RANK=${LOG_RANK:-0,1,2,3}