 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
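+# With the CLI defined at the bottom of this file, the model name is a required
+# positional argument and --pp sets the pipeline degree, e.g.:
+#   torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2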
 
 logger = SingletonLogger.get_logger()
 
-MODEL_NAME = "Transformer-2-7b-chat-hf"
-NAME_TO_HF_MODEL_ID_AND_DTYPE = {
-    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+# The model name, e.g. "llama2-7b-chat", selects which model to load; any key in
+# the mapping below is valid. For details on the name-to-distribution mapping,
+# see README.md or models.json.
+NAME_TO_DISTRIBUTION_AND_DTYPE = {
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }
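+# Precision for the KV cache, applied via set_precision() in main(); kept
+# separate from the per-model weight dtype above.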
 CACHE_PRECISION = torch.bfloat16
 
@@ -78,8 +81,19 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
 
 
 def _build_chat_tokenizer(
-    model_base_name: str = "llama3",
+    model_name: str,
+    model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
+    """Builds a tokenizer for the given model name."""
+    # Try to infer the model base name from the model name:
+    # e.g. "llama2-7b-chat" -> "llama2"
+    if model_base_name is None:
+        model_base_name = model_name.split("-")[0]
+        logger.info(
+            f"Using model base name '{model_base_name}' to build tokenizer. "
+            "If not found, please specify it using the `model_base_name` argument."
+        )
+
     # Create base args for tokenizer
     default_model_dir = Path(
         os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
@@ -100,12 +114,12 @@ def _build_chat_tokenizer(
     return tokenizer
 
 
-def _load_model_weights(stage_module, hf_model_name, device, model_config):
+def _load_model_weights(stage_module, distribution, device, model_config):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
     """
 
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(hf_model_name)
+    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
 
     num_loaded_weights, num_missing_weights = load_safetensor_weights(
         stage_module,
@@ -127,32 +141,31 @@ def _cleanup():
     dist.destroy_process_group()
 
 
-def main():
+def main(args):
+    model_name = args.model_name
+    pp_degree = args.pp
+
     rank, world_size = _init_distributed()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()} {color.reset}")
 
-    config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']
-    logger.info(f"Chat Model Config: {config}")
+    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
+    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")
 
-    tokenizer = _build_chat_tokenizer()
-    logger.info(f"built tokenizer {tokenizer=}")
+    config = ModelArgs.from_name(distribution).transformer_args['text']
+    logger.info(f"Chat Model Config: {config}")
 
-    hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
-    logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")
+    tokenizer = _build_chat_tokenizer(model_name)
 
     set_precision(CACHE_PRECISION)
     logger.info(f"Using cache precision {CACHE_PRECISION}")
 
-    hf_config = get_hf_config_file(hf_model_name)
+    hf_config = get_hf_config_file(distribution)
     if hf_config is None:
-        raise ValueError(f"Config file not found for model id {hf_model_name}")
-    logger.info(f"Using HF model weights from {hf_model_name}")
+        raise ValueError(f"Config file not found for model id {distribution}")
 
-    # Assuming 2 pipeline stages, feel free to change this as long as the
-    # asserts are satisfied
-    pp_degree = 2
+    # Validate pipeline degree
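+    # world_size must be divisible by pp_degree, and n_layers must split
+    # evenly across the pipeline stages.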
     assert world_size % pp_degree == 0
     assert config.n_layers % pp_degree == 0
 
@@ -182,7 +195,8 @@ def main():
 
     # Distribute model on TP mesh
     model.distribute(tp_mesh)
-    logger.info(f"Model: {model}")
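+    # Log the full model structure from rank 0 only, rather than once per rank.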
+    if rank == 0:
+        logger.info(f"Model: {model}")
 
     mbs = 2  # number of micro-batches
     mb_size = 1  # micro-batch size
@@ -200,7 +214,7 @@ def main():
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with TrackTime("cuda") as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
@@ -253,7 +267,7 @@ def main():
 
     with torch.no_grad():  # .inference_mode():
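+        # Only the last pipeline stage returns the model's output from step().
+        # Capturing it on the first stage, too, covers the --pp 1 case, where
+        # stage 0 is also the last stage.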
         if pp_rank == 0:
-            schedule.step(input_ids)
+            output = schedule.step(input_ids)
         else:
             output = schedule.step()
 
@@ -274,4 +288,9 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
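+    # --pp defaults to 1 (a single pipeline stage, i.e. tensor parallelism
+    # only); main() asserts that world_size is divisible by this value.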
+    args = parser.parse_args()
+
+    main(args)