@@ -55,9 +55,12 @@
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
 # For details on the name-to-distribution mapping, see README.md or models.json.
+
+# Name: HF distribution name, dtype, and model dimension
 NAME_TO_DISTRIBUTION_AND_DTYPE = {
-    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
+    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
 }
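A minimal sketch of how the widened mapping is consumed, mirroring the lookup changed later in this diff (assumes torch is imported, since the dtype constants in the mapping need it):

    # Sketch only: unpack the new three-element tuple for one of the entries above.
    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE["llama3-70b"]
    # -> ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192)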
@@ -314,8 +317,12 @@ def main(args):
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")

-    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
+    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE[
+        model_name
+    ]
+    logger.info(
+        f"Using model weights from {distribution}, dtype {model_dtype} and model dimension {model_dimension}"
+    )

     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -338,6 +345,7 @@ def main(args):
     # Tensor parallel is enabled in this program
     tp_degree = world_size // pp_degree
+    logger.info(f"Using TP degree {tp_degree} and PP degree {pp_degree}")

     # Create device mesh
     mesh_dimensions = (pp_degree, tp_degree)
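For a concrete feel of the mesh math above, a small worked example (the world_size and pp_degree values are assumed for illustration, not taken from this commit):

    # Illustrative values only.
    world_size = 8                              # total ranks
    pp_degree = 2                               # pipeline-parallel stages
    tp_degree = world_size // pp_degree         # -> 4 tensor-parallel ranks per stage
    mesh_dimensions = (pp_degree, tp_degree)    # -> (2, 4)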
@@ -388,7 +396,6 @@ def main(args):
     # sense. Thus it is interchangeable with micro-batch size below.
     batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
-    dim = 4096  # embedding dimension

     # Setup KV caches (after model distribution)
     # The number of cache lanes is the same as the maximum number of
@@ -419,7 +426,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             0, config.vocab_size, (batch_size, seqlen), device=device
         )
         activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, model_dimension, device=device, dtype=model_dtype
         )
         logits = torch.rand(
             batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
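As a quick sanity check of the model_dimension plumbing, a minimal sketch of the tensor shapes get_example_ins_outs now produces; the batch_size, seqlen, and vocab_size values below are assumed for illustration (in the real script they come from the prompt and model config), and model_dimension is the 8192 from the new "llama3-70b" entry:

    import torch

    # Assumed example values, kept small so this runs cheaply on CPU.
    batch_size, seqlen, model_dimension, vocab_size = 2, 16, 8192, 128256
    activation = torch.rand(batch_size, seqlen, model_dimension, dtype=torch.bfloat16)
    logits = torch.rand(batch_size, seqlen, vocab_size, dtype=torch.bfloat16)
    print(activation.shape)  # torch.Size([2, 16, 8192])
    print(logits.shape)      # torch.Size([2, 16, 128256])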