 
 import torch.multiprocessing as mp
 from torchchat.cli.builder import BuilderArgs, TokenizerArgs
+from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE
 
 
 def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
@@ -33,7 +34,7 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
 
 
 def _launch_distributed_inference(
-    builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
+    model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
 ) -> tuple[List]:
     # create programmatic elastic launch
     print("Launching distributed inference ...")
@@ -51,7 +52,7 @@ def _launch_distributed_inference(
         pipes.append(server_pipe)
         proc = mp.Process(
             target=partial(_setup_env, num_processes_per_node, rank, main),
-            args=(builder_args, tokenizer_args, client_pipe),
+            args=(model_name, builder_args, tokenizer_args, client_pipe),
         )
         proc.start()
 
@@ -178,6 +179,8 @@ def step(self) -> List[Output]:
 class DistributedGenerator(object):
     def __init__(
         self,
+        # TODO: switch this to torchchat method
+        model_name: str,
         builder_args: BuilderArgs,
         tokenizer_args: TokenizerArgs,
         # TODO: move GeneratorArgs into a different module
@@ -186,13 +189,14 @@ def __init__(
         quantize: bool,
         draft_quantize: bool,
     ):
+        self.model_name = model_name
         self.builder_args = builder_args
         self.generate_args = generator_args
 
         self.check_args()
 
         self.procs, self.pipes = _launch_distributed_inference(
-            builder_args, tokenizer_args
+            model_name, builder_args, tokenizer_args
         )
 
         self.loop = asyncio.new_event_loop()
@@ -236,4 +240,12 @@ def check_args(self):
                 "Currently we only support generate with --distributed"
             )
         elif self.builder_args.tp < 2:
-            raise RuntimeError("TP degree must be at least 2 for distributed inference")
+            raise ValueError("TP degree must be at least 2 for distributed inference")
+        elif self.model_name not in NAME_TO_DISTRIBUTION_AND_DTYPE.keys():
+            raise ValueError(
+                f"Distributed inference currently only supports the following models: {list(NAME_TO_DISTRIBUTION_AND_DTYPE.keys())}"
+            )
+        elif self.builder_args.chpt_from == "torchchat":
+            raise ValueError(
+                "Distributed inference currently only supports HF checkpoints"
+            )
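For reference, here is a self-contained sketch of the argument validation this commit adds to check_args(). The table entry and the stand-in dataclass below are invented for illustration; the real NAME_TO_DISTRIBUTION_AND_DTYPE lives in torchchat.distributed.dist_run and the real fields come from BuilderArgs.

# Sketch only: mirrors the new check_args() branches with stand-in data.
from dataclasses import dataclass

# Invented stand-in; the real table is defined in torchchat.distributed.dist_run.
NAME_TO_DISTRIBUTION_AND_DTYPE = {
    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", "bfloat16"),
}

@dataclass
class FakeBuilderArgs:  # hypothetical stand-in for BuilderArgs
    tp: int = 2
    chpt_from: str = "hf"

def check_args(model_name: str, builder_args: FakeBuilderArgs) -> None:
    if builder_args.tp < 2:
        raise ValueError("TP degree must be at least 2 for distributed inference")
    if model_name not in NAME_TO_DISTRIBUTION_AND_DTYPE:
        raise ValueError(
            f"Distributed inference currently only supports the following models: "
            f"{list(NAME_TO_DISTRIBUTION_AND_DTYPE)}"
        )
    if builder_args.chpt_from == "torchchat":
        raise ValueError("Distributed inference currently only supports HF checkpoints")

check_args("llama3", FakeBuilderArgs())    # passes
# check_args("gpt2", FakeBuilderArgs())    # would raise: unsupported model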
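Likewise, a minimal runnable sketch of the per-rank launch pattern used by _launch_distributed_inference: one duplex Pipe and one mp.Process per rank, with a _setup_env-style wrapper that configures the torch.distributed environment variables in the child before calling the real entry point. The worker body, rendezvous address, port, and model name are assumptions for illustration, not taken from this commit.

import os
from functools import partial
import torch.multiprocessing as mp

def _setup_env(world_size, rank, target, *args, **kwargs):
    # Mirrors the diff's helper: set rendezvous env vars in the child
    # process, then invoke the real target.
    os.environ["MASTER_ADDR"] = "localhost"   # assumed rendezvous address
    os.environ["MASTER_PORT"] = "29500"       # assumed rendezvous port
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    return target(*args, **kwargs)

def worker(model_name, pipe):
    # Stand-in for the real main(): report back over the client pipe.
    pipe.send(f"rank {os.environ['RANK']} would load {model_name}")

if __name__ == "__main__":
    world_size, procs, pipes = 2, [], []
    for rank in range(world_size):
        server_pipe, client_pipe = mp.Pipe(duplex=True)
        pipes.append(server_pipe)
        proc = mp.Process(
            target=partial(_setup_env, world_size, rank, worker),
            args=("llama3", client_pipe),     # model name is illustrative
        )
        proc.start()
        procs.append(proc)
    for pipe in pipes:
        print(pipe.recv())
    for proc in procs:
        proc.join()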