diff --git a/dist_run.py b/dist_run.py
index 30bf92669..2b4ab67cb 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -55,9 +55,12 @@
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
 # For details on the name-to-distribution mapping, see README.md or models.json.
+
+# Name : HF distribution name, dtype, and model dimension
 NAME_TO_DISTRIBUTION_AND_DTYPE = {
-    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
+    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
 }
 
 
@@ -314,8 +317,12 @@ def main(args):
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
-    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
+    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE[
+        model_name
+    ]
+    logger.info(
+        f"Using model weights from {distribution}, dtype {model_dtype} and model dimension {model_dimension}"
+    )
 
     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -338,6 +345,7 @@ def main(args):
 
     # Tensor parallel is enabled in this program
     tp_degree = world_size // pp_degree
+    logger.info(f"Using TP degree {tp_degree} and PP degree {pp_degree}")
 
     # Create device mesh
     mesh_dimensions = (pp_degree, tp_degree)
@@ -388,7 +396,6 @@ def main(args):
     # sense. Thus it is interchangeable with micro-batch size below.
     batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
-    dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
     # The number of cache lanes is the same as the maximum number of
@@ -419,7 +426,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             0, config.vocab_size, (batch_size, seqlen), device=device
         )
         activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, model_dimension, device=device, dtype=model_dtype
        )
         logits = torch.rand(
             batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
diff --git a/torchchat/distributed/force_download.py b/torchchat/distributed/force_download.py
index 76dba8d0c..c57509c70 100644
--- a/torchchat/distributed/force_download.py
+++ b/torchchat/distributed/force_download.py
@@ -1,5 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")
 print("Model weights and tokenizer downloaded")
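For context, a minimal standalone sketch of how the widened mapping is consumed after this change; the hard-coded "llama3-70b" key and the toy batch/sequence sizes are illustrative assumptions, not code from dist_run.py:

```python
import torch

# The widened mapping from the diff above:
# name -> (HF distribution name, dtype, model dimension).
NAME_TO_DISTRIBUTION_AND_DTYPE = {
    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
}

# Three-way unpacking replaces the old hard-coded `dim = 4096`, so the example
# activation tensor follows the selected model's embedding dimension.
distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE["llama3-70b"]

batch_size, seqlen = 1, 1024  # illustrative values; dist_run.py derives its own
activation = torch.rand(batch_size, seqlen, model_dimension, dtype=model_dtype)
print(activation.shape)  # torch.Size([1, 1024, 8192])
```

With the previous hard-coded `dim`, the 70B model's example activation would have been built with dimension 4096 instead of 8192, producing shape mismatches during pipeline-stage setup.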