dist_run.py (13 additions & 6 deletions)
@@ -55,9 +55,12 @@
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
 # For details on the name-to-distribution mapping, see README.md or models.json.
+
+# Name : HF distribution name, dtype, and model dimension
 NAME_TO_DISTRIBUTION_AND_DTYPE = {
-    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
+    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
 }


@@ -314,8 +317,12 @@ def main(args):
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
-    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
+    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE[
+        model_name
+    ]
+    logger.info(
+        f"Using model weights from {distribution}, dtype {model_dtype} and model dimension {model_dimension}"
+    )
 
     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -338,6 +345,7 @@ def main(args):

     # Tensor parallel is enabled in this program
     tp_degree = world_size // pp_degree
+    logger.info(f"Using TP degree {tp_degree} and PP degree {pp_degree}")
 
     # Create device mesh
     mesh_dimensions = (pp_degree, tp_degree)
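
For context on the new log line: the total world size is split between pipeline parallelism and tensor parallelism, so the TP degree is just the number of ranks divided by the number of pipeline stages, and the two degrees together define the device mesh. A minimal sketch of that arithmetic, with illustrative numbers that are assumptions rather than values from this PR:

# Illustrative only: 8 ranks split into 2 pipeline stages leaves 4 ranks per stage for TP.
world_size = 8                             # assumed total number of ranks
pp_degree = 2                              # assumed number of pipeline stages
tp_degree = world_size // pp_degree        # -> 4
mesh_dimensions = (pp_degree, tp_degree)   # -> (2, 4) device mesh, as in dist_run.py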
@@ -388,7 +396,6 @@ def main(args):
     # sense. Thus it is interchangeable with micro-batch size below.
     batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
-    dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
     # The number of cache lanes is the same as the maximum number of
@@ -419,7 +426,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             0, config.vocab_size, (batch_size, seqlen), device=device
         )
         activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, model_dimension, device=device, dtype=model_dtype
         )
         logits = torch.rand(
             batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
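
Taken together, the dist_run.py changes replace the hardcoded dim = 4096 with a per-model dimension carried in the lookup table, which is what allows the 70B model (hidden size 8192) to get correctly shaped example activations in get_example_ins_outs. A standalone sketch of the pattern, reusing names from the diff but simplified for illustration:

import torch

# Per-model metadata: HF distribution name, dtype, and model (embedding) dimension.
NAME_TO_DISTRIBUTION_AND_DTYPE = {
    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
}

distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE["llama3-70b"]

# Example activation shaped from the table instead of a hardcoded dim = 4096.
batch_size, seqlen = 2, 1024
activation = torch.rand(batch_size, seqlen, model_dimension, dtype=model_dtype)
print(activation.shape)  # torch.Size([2, 1024, 8192])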
torchchat/distributed/force_download.py (2 additions & 2 deletions)
@@ -1,5 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")
 print("Model weights and tokenizer downloaded")
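
One usage note on the updated download script: the Meta-Llama-3-70B-Instruct repository is gated on Hugging Face, so from_pretrained will only succeed for an account that has accepted the license and is authenticated (for example via huggingface-cli login or an HF_TOKEN environment variable). A hedged sketch of the same download with an explicit token; the token handling is an assumption for illustration, not part of this PR:

import os
from transformers import AutoModelForCausalLM

# Assumption: HF_TOKEN holds a token for an account with access to the gated repo.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    token=os.environ.get("HF_TOKEN"),
)
print("Model weights downloaded")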