This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7a4f0d1

kwen2501 and Jack-Khuu authored
[Distributed] Accept model name (#1148)
Co-authored-by: Jack-Khuu <[email protected]>
1 parent 26c1d8b commit 7a4f0d1

File tree

1 file changed (+44, -25 lines)


dist_run.py

Lines changed: 44 additions & 25 deletions
@@ -4,10 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
@@ -52,10 +53,12 @@
 
 logger = SingletonLogger.get_logger()
 
-MODEL_NAME = "Transformer-2-7b-chat-hf"
-NAME_TO_HF_MODEL_ID_AND_DTYPE = {
-    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+# Using model name to identify the model to load, for example "llama2-7b-chat".
+# You can change it to other values listed below.
+# For details on the name-to-distribution mapping, see README.md or models.json.
+NAME_TO_DISTRIBUTION_AND_DTYPE = {
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }
 CACHE_PRECISION = torch.bfloat16
 
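For readers skimming the diff, here is a minimal sketch of the lookup this new table enables. The two entries are copied from the hunk above; the lookup line mirrors what main() now does further down.

```python
import torch

# Entries copied from the diff above.
NAME_TO_DISTRIBUTION_AND_DTYPE = {
    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
}

# Resolve a user-facing model name to its HF distribution and dtype.
distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE["llama2-7b-chat"]
print(distribution, model_dtype)  # meta-llama/Llama-2-7b-chat-hf torch.float16
```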

@@ -78,8 +81,19 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 
 
 def _build_chat_tokenizer(
-    model_base_name: str = "llama3",
+    model_name: str,
+    model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
+    """Builds a tokenizer for the given model name."""
+    # Try to infer the model base name from the model name:
+    # e.g. "llama2-7b-chat" -> "llama2"
+    if model_base_name is None:
+        model_base_name = model_name.split("-")[0]
+        logger.info(
+            f"Using model base name '{model_base_name}' to build tokenizer. "
+            "If not found, please specify it using the `model_base_name` argument."
+        )
+
     # Create base args for tokenizer
     default_model_dir = Path(
         os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
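A minimal sketch of the base-name inference `_build_chat_tokenizer` now performs when `model_base_name` is not supplied; the names below are the two keys from the table above:

```python
# Infer the tokenizer "base" family by splitting the model name on "-"
# and keeping the first token, as the hunk above does.
for model_name in ("llama2-7b-chat", "llama3"):
    model_base_name = model_name.split("-")[0]
    print(f"{model_name} -> {model_base_name}")
# llama2-7b-chat -> llama2
# llama3 -> llama3
```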
@@ -100,12 +114,12 @@ def _build_chat_tokenizer(
     return tokenizer
 
 
-def _load_model_weights(stage_module, hf_model_name, device, model_config):
+def _load_model_weights(stage_module, distribution, device, model_config):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
     """
 
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(hf_model_name)
+    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
 
     num_loaded_weights, num_missing_weights = load_safetensor_weights(
         stage_module,
@@ -127,32 +141,31 @@ def _cleanup():
     dist.destroy_process_group()
 
 
-def main():
+def main(args):
+    model_name = args.model_name
+    pp_degree = args.pp
+
     rank, world_size = _init_distributed()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
-    config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']
-    logger.info(f"Chat Model Config: {config}")
+    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
+    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
 
-    tokenizer = _build_chat_tokenizer()
-    logger.info(f"built tokenizer {tokenizer=}")
+    config = ModelArgs.from_name(distribution).transformer_args['text']
+    logger.info(f"Chat Model Config: {config}")
 
-    hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
-    logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")
+    tokenizer = _build_chat_tokenizer(model_name)
 
     set_precision(CACHE_PRECISION)
     logger.info(f"Using cache precision {CACHE_PRECISION}")
 
-    hf_config = get_hf_config_file(hf_model_name)
+    hf_config = get_hf_config_file(distribution)
     if hf_config is None:
-        raise ValueError(f"Config file not found for model id {hf_model_name}")
-    logger.info(f"Using HF model weights from {hf_model_name}")
+        raise ValueError(f"Config file not found for model id {distribution}")
 
-    # Assuming 2 pipeline stages, feel free to change this as long as the
-    # asserts are satisfied
-    pp_degree = 2
+    # Validate pipeline degree
     assert world_size % pp_degree == 0
     assert config.n_layers % pp_degree == 0
 
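Since the pipeline degree is now user-supplied via `--pp` rather than hard-coded to 2, these two asserts become the only guard. A small self-contained sketch of the same divisibility checks; `validate_pp_degree` is a hypothetical helper and the numbers are purely illustrative:

```python
def validate_pp_degree(world_size: int, n_layers: int, pp_degree: int) -> None:
    # Mirrors the two asserts in main(): every pipeline stage must own an
    # equal share of ranks and an equal share of transformer layers.
    assert world_size % pp_degree == 0, "world_size must be divisible by pp degree"
    assert n_layers % pp_degree == 0, "n_layers must be divisible by pp degree"

validate_pp_degree(world_size=4, n_layers=32, pp_degree=2)  # passes
# validate_pp_degree(world_size=4, n_layers=32, pp_degree=3) would raise AssertionError
```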

@@ -182,7 +195,8 @@ def main():
 
     # Distribute model on TP mesh
     model.distribute(tp_mesh)
-    logger.info(f"Model: {model}")
+    if rank == 0:
+        logger.info(f"Model: {model}")
 
     mbs = 2  # number of micro-batches
     mb_size = 1  # micro-batch size
@@ -200,7 +214,7 @@ def main():
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with TrackTime("cuda") as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
@@ -253,7 +267,7 @@ def main():
 
     with torch.no_grad():  # .inference_mode():
         if pp_rank == 0:
-            schedule.step(input_ids)
+            output = schedule.step(input_ids)
         else:
             output = schedule.step()
 
@@ -274,4 +288,9 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    args = parser.parse_args()
+
+    main(args)
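A sketch of the new command-line surface: the parser mirrors the hunk above, while the torchrun invocation in the trailing comment is an assumption based on the existing run-command comment at the top of the file, not something stated in the diff.

```python
import argparse

# Same two arguments the commit adds to dist_run.py.
parser = argparse.ArgumentParser()
parser.add_argument("model_name", type=str, choices=["llama2-7b-chat", "llama3"])
parser.add_argument("--pp", type=int, default=1)

# Parse a sample command line to show the resulting namespace.
args = parser.parse_args(["llama3", "--pp", "2"])
assert args.model_name == "llama3" and args.pp == 2
# Presumably launched as something like:
#   torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2
```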
