@@ -64,6 +64,7 @@ class BuilderArgs:
6464 pp : int = 1
6565 tp : int = 1
6666 chpt_from : str = "hf"
67+ distribution_path : Optional [str ] = None
6768 is_chat_model : bool = False
6869 prefill_possible : bool = False
6970 dynamic_shapes : bool = False
@@ -129,6 +130,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
129130 model_config .transformer_params_key or model_config .name .split ("/" )[- 1 ]
130131 )
131132
133+ distribution_path = model_config .distribution_path
134+
132135 dso_path = getattr (args , "dso_path" , None )
133136 pte_path = getattr (args , "pte_path" , None )
134137 aoti_package_path = getattr (args , "aoti_package_path" , None )
@@ -194,6 +197,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
194197 pp = pp ,
195198 tp = tp ,
196199 chpt_from = chpt_from ,
200+ distribution_path = distribution_path ,
197201 is_chat_model = is_chat_model ,
198202 dynamic_shapes = getattr (args , "dynamic_shapes" , False ),
199203 max_seq_length = getattr (args , "max_seq_length" , None ),
@@ -607,23 +611,6 @@ def do_nothing(max_batch_size, max_seq_length):
607611 except Exception :
608612 raise RuntimeError (f"Failed to load ET compiled { builder_args .pte_path } " )
609613 elif builder_args .distributed :
610- # Using params_table to identify the model to load, for example "Meta-Llama-3.1-8B".
611- #TODO This is a hacky way to please the distributed loading api and needs to be replaced
612- NAME_TO_DISTRIBUTION = {
613- "Meta-Llama-3-8B" : "meta-llama/Meta-Llama-3-8B-Instruct" ,
614- "Meta-Llama-3.1-8B" : "meta-llama/Meta-Llama-3.1-8B-Instruct" ,
615- "Meta-Llama-3-70B" : "meta-llama/Meta-Llama-3-70B-Instruct" ,
616- "Meta-Llama-3.1-70B" : "meta-llama/Meta-Llama-3.1-70B-Instruct" ,
617-
618- }
619- # TODO: Use information in builder_args directly to build model and load weights
620- assert builder_args .params_table
621- try :
622- distribution = NAME_TO_DISTRIBUTION [builder_args .params_table ]
623- except KeyError as e :
624- print (f"Unknown params_table: { builder_args .params_table } . Suported model names are: llama3.1, llama3, llama2-7b-chat" )
625- raise e
626-
627614 pp_degree = builder_args .pp
628615 tp_degree = builder_args .tp
629616
@@ -687,7 +674,7 @@ def do_nothing(max_batch_size, max_seq_length):
687674 # Load weights
688675 logger .info (f"Loading weights for { pp_rank = } on { device = } " )
689676 with CUDATrackTime () as timer :
690- load_model_weights (model , distribution , device , config , builder_args .chpt_from )
677+ load_model_weights (model , builder_args . distribution_path , device , config , builder_args .chpt_from )
691678
692679 logger .info (
693680 f"{ color .green } Total weight loading time: { timer .get_time ()} { timer .unit } for rank { rank } { color .reset } "