69 changes: 44 additions & 25 deletions dist_run.py
@@ -4,10 +4,11 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
from typing import Any, Dict, Optional

# Run command:
# torchrun --nproc-per-node 4 dist_run.py
@@ -52,10 +53,12 @@

logger = SingletonLogger.get_logger()

MODEL_NAME = "Transformer-2-7b-chat-hf"
NAME_TO_HF_MODEL_ID_AND_DTYPE = {
"Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
"Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
# The model name identifies which model to load, for example "llama2-7b-chat".
# You can change it to any of the other names listed below.
# For details on the name-to-distribution mapping, see README.md or models.json.
NAME_TO_DISTRIBUTION_AND_DTYPE = {
"llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
"llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
}
CACHE_PRECISION = torch.bfloat16
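
For illustration, a minimal lookup sketch (all names come from the dict above; the real lookup happens later in main):

# Sketch: resolve a CLI model name to its HF repo id and dtype.
distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE["llama2-7b-chat"]
# distribution == "meta-llama/Llama-2-7b-chat-hf", model_dtype == torch.float16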

@@ -78,8 +81,19 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:


def _build_chat_tokenizer(
model_base_name: str = "llama3",
model_name: str,
model_base_name: Optional[str] = None,
) -> SentencePieceProcessor | TiktokenTokenizer:
"""Builds a tokenizer for the given model name."""
# Try to infer the model base name from the model name:
# e.g. "llama2-7b-chat" -> "llama2"
if model_base_name is None:
model_base_name = model_name.split("-")[0]
logger.info(
f"Using model base name '{model_base_name}' to build tokenizer. "
"If not found, please specify it using the `model_base_name` argument."
)

# Create base args for tokenizer
default_model_dir = Path(
os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
@@ -100,12 +114,12 @@ def _build_chat_tokenizer(
return tokenizer
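
Usage sketch (hypothetical calls, using names from the diff): because the base name is inferred from the model name, these two calls should be equivalent for the llama2 entry.

# Base name "llama2" is inferred from "llama2-7b-chat", so both calls match:
tokenizer = _build_chat_tokenizer("llama2-7b-chat")
tokenizer = _build_chat_tokenizer("llama2-7b-chat", model_base_name="llama2")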


def _load_model_weights(stage_module, hf_model_name, device, model_config):
def _load_model_weights(stage_module, distribution, device, model_config):
"""Load the weights from the safetensor file(s) into the model stage.
The model config is needed because we permute the wq and wk weights based on the number of attention heads.
"""

weight_map, weight_path, key_map = get_hf_weight_map_and_path(hf_model_name)
weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)

num_loaded_weights, num_missing_weights = load_safetensor_weights(
stage_module,
@@ -127,32 +141,31 @@ def _cleanup():
dist.destroy_process_group()


def main():
def main(args):
model_name = args.model_name
pp_degree = args.pp

rank, world_size = _init_distributed()

gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']
logger.info(f"Chat Model Config: {config}")
distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")

tokenizer = _build_chat_tokenizer()
logger.info(f"built tokenizer {tokenizer=}")
config = ModelArgs.from_name(distribution).transformer_args['text']
logger.info(f"Chat Model Config: {config}")

hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")
tokenizer = _build_chat_tokenizer(model_name)

set_precision(CACHE_PRECISION)
logger.info(f"Using cache precision {CACHE_PRECISION}")

hf_config = get_hf_config_file(hf_model_name)
hf_config = get_hf_config_file(distribution)
if hf_config is None:
raise ValueError(f"Config file not found for model id {hf_model_name}")
logger.info(f"Using HF model weights from {hf_model_name}")
raise ValueError(f"Config file not found for model id {distribution}")

# Assuming 2 pipeline stages, feel free to change this as long as the
# asserts are satisfied
pp_degree = 2
# Validate pipeline degree
assert world_size % pp_degree == 0
assert config.n_layers % pp_degree == 0
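# For example (illustrative numbers, not from the diff): with world_size=4 and
# --pp 2, the 4 ranks form 2 pipeline stages, and config.n_layers must split
# evenly into 2 contiguous groups of layers, one group per stage.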

@@ -182,7 +195,8 @@ def main():

# Distribute model on TP mesh
model.distribute(tp_mesh)
logger.info(f"Model: {model}")
if rank == 0:
[Review comment, Contributor] nit: we already filter by rank in the dist_run.sh launcher, rather than having explicit rank-based logging in the code.

[Reply, Contributor Author] Yeah, it is a hard choice. I don't know if regular users use the launcher's filter option (I never remember its name), and I want a minimal command line to just work.

logger.info(f"Model: {model}")

mbs = 2 # number of micro-batches
mb_size = 1 # micro-batch size
@@ -200,7 +214,7 @@ def main():
# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
with TrackTime("cuda") as timer:
_load_model_weights(model, hf_model_name, device=device, model_config=config)
_load_model_weights(model, distribution, device=device, model_config=config)
logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
)
@@ -253,7 +267,7 @@ def main():

with torch.no_grad(): # .inference_mode():
if pp_rank == 0:
schedule.step(input_ids)
output = schedule.step(input_ids)
else:
output = schedule.step()
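# Note (assumption about torch.distributed.pipelining semantics): `step` returns
# real outputs only on the last pipeline stage; earlier stages return None, so
# decoding of `output` should remain gated on the final pp_rank.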

@@ -274,4 +288,9 @@


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
args = parser.parse_args()

main(args)
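
A run-command sketch, combining the torchrun invocation noted near the top of the file with the new CLI arguments (the pipeline degree is illustrative; it only needs to divide the world size and layer count):

# Example (sketch): 4 GPUs, 2 pipeline stages, llama3 weights.
# torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2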