1616import torch ._inductor .config
1717import torch .nn as nn
1818
19- from torch .distributed import launcher
20-
2119from torch .distributed .device_mesh import DeviceMesh
2220from torch .distributed .elastic .multiprocessing .errors import record
2321from torch .distributed .elastic .utils .distributed import get_free_port
24- from torch .distributed .launcher .api import elastic_launch
2522
2623from torchchat .distributed import launch_distributed , ParallelDims , parallelize_llama
2724
@@ -65,6 +62,8 @@ class BuilderArgs:
6562 num_nodes : int = 1
6663 pp : int = 1
6764 tp : int = 1
65+ chpt_from : str = "hf"
66+ ntokens : int = 40
6867 is_chat_model : bool = False
6968 prefill_possible : bool = False
7069 dynamic_shapes : bool = False
@@ -171,6 +170,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
171170 num_nodes = getattr (args , "num_nodes" , 1 )
172171 pp = getattr (args , "pp" , 1 )
173172 tp = getattr (args , "tp" , 1 )
173+ chpt_from = getattr (args , "chpt_from" , "hf" )
174+ ntokens = getattr (args , "ntokens" , 40 )
174175 return cls (
175176 checkpoint_dir = checkpoint_dir ,
176177 checkpoint_path = checkpoint_path ,
@@ -189,6 +190,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
189190 num_nodes = num_nodes ,
190191 pp = pp ,
191192 tp = tp ,
193+ chpt_from = chpt_from ,
194+ ntokens = ntokens ,
192195 is_chat_model = is_chat_model ,
193196 dynamic_shapes = getattr (args , "dynamic_shapes" , False ),
194197 max_seq_length = getattr (args , "max_seq_length" , None ),
@@ -508,7 +511,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
508511
509512 model = model .to (device = builder_args .device , dtype = builder_args .precision )
510513 return model .eval ()
511-
514+
512515
513516def _initialize_model (
514517 builder_args : BuilderArgs ,
0 commit comments