diff --git a/dist_run.py b/dist_run.py
index 06cfa341c..79a3d2f84 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -4,10 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import argparse
 import os
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
@@ -52,10 +53,12 @@
 logger = SingletonLogger.get_logger()

-MODEL_NAME = "Transformer-2-7b-chat-hf"

-NAME_TO_HF_MODEL_ID_AND_DTYPE = {
-    "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+# Using model name to identify the model to load, for example "llama2-7b-chat".
+# You can change it to other values listed below.
+# For details on the name-to-distribution mapping, see README.md or models.json.
+NAME_TO_DISTRIBUTION_AND_DTYPE = {
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }

 CACHE_PRECISION = torch.bfloat16
@@ -78,8 +81,19 @@ def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:


 def _build_chat_tokenizer(
-    model_base_name: str = "llama3",
+    model_name: str,
+    model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
+    """Builds a tokenizer for the given model name."""
+    # Try to infer the model base name from the model name:
+    # e.g. "llama2-7b-chat" -> "llama2"
+    if model_base_name is None:
+        model_base_name = model_name.split("-")[0]
+        logger.info(
+            f"Using model base name '{model_base_name}' to build tokenizer. "
+            "If not found, please specify it using the `model_base_name` argument."
+        )
+    # Create base args for tokenizer
     default_model_dir = Path(
         os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
@@ -100,12 +114,12 @@ def _build_chat_tokenizer(
     return tokenizer


-def _load_model_weights(stage_module, hf_model_name, device, model_config):
+def _load_model_weights(stage_module, distribution, device, model_config):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
     """
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(hf_model_name)
+    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)

     num_loaded_weights, num_missing_weights = load_safetensor_weights(
         stage_module,
@@ -127,32 +141,31 @@ def _cleanup():
     dist.destroy_process_group()


-def main():
+def main(args):
+    model_name = args.model_name
+    pp_degree = args.pp
+
     rank, world_size = _init_distributed()

     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

-    config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']
-    logger.info(f"Chat Model Config: {config}")
+    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
+    logger.info(f"Using HF model weights from {distribution} and dtype {model_dtype}")

-    tokenizer = _build_chat_tokenizer()
-    logger.info(f"built tokenizer {tokenizer=}")
+    config = ModelArgs.from_name(distribution).transformer_args['text']
+    logger.info(f"Chat Model Config: {config}")

-    hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
-    logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")
+    tokenizer = _build_chat_tokenizer(model_name)

     set_precision(CACHE_PRECISION)
     logger.info(f"Using cache precision {CACHE_PRECISION}")

-    hf_config = get_hf_config_file(hf_model_name)
+    hf_config = get_hf_config_file(distribution)
     if hf_config is None:
-        raise ValueError(f"Config file not found for model id {hf_model_name}")
-    logger.info(f"Using HF model weights from {hf_model_name}")
+        raise ValueError(f"Config file not found for model id {distribution}")

-    # Assuming 2 pipeline stages, feel free to change this as long as the
-    # asserts are satisfied
-    pp_degree = 2
+    # Validate pipeline degree
     assert world_size % pp_degree == 0
     assert config.n_layers % pp_degree == 0

@@ -182,7 +195,8 @@ def main():

     # Distribute model on TP mesh
     model.distribute(tp_mesh)
-    logger.info(f"Model: {model}")
+    if rank == 0:
+        logger.info(f"Model: {model}")

     mbs = 2  # number of micro-batches
     mb_size = 1  # micro-batch size
@@ -200,7 +214,7 @@ def main():
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with TrackTime("cuda") as timer:
-        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+        _load_model_weights(model, distribution, device=device, model_config=config)
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
@@ -253,7 +267,7 @@ def main():

     with torch.no_grad():  # .inference_mode():
         if pp_rank == 0:
-            schedule.step(input_ids)
+            output = schedule.step(input_ids)
         else:
             output = schedule.step()

@@ -274,4 +288,9 @@ def main():


 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_name", type=str, help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys())
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
+    args = parser.parse_args()
+
+    main(args)