 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import asyncio
+import atexit
+import importlib.util
+import subprocess
+import threading
 from abc import abstractmethod
 from collections import deque
 from dataclasses import dataclass
 from functools import partial
 from os import environ
 from pathlib import Path
-from torchchat.cli.builder import BuilderArgs, TokenizerArgs
 from typing import List, Optional
 from uuid import uuid4
 
-import asyncio
-import atexit
 import torch.multiprocessing as mp
-import threading
-import importlib.util
-import subprocess
+from torchchat.cli.builder import BuilderArgs, TokenizerArgs
 
 
-def _setup_env(world_size:int, rank:int, target: callable, *args, **kwargs):
+def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
     environ["MASTER_ADDR"] = "localhost"
     environ["MASTER_PORT"] = "29500"
     environ["RDZV_BACKEND"] = "c10d"
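Note: `_setup_env` exports the rendezvous variables before each spawned process runs its target; the hunk is cut off here, but the remaining lines presumably export `RANK` and `WORLD_SIZE` and then invoke `target(*args, **kwargs)`. A minimal sketch of the consumer side, assuming `dist_run.main` initializes torch.distributed from this environment (the backend choice and the exact call are assumptions, not shown in the diff):

```python
import os

import torch.distributed as dist


def _worker_entry():
    # MASTER_ADDR / MASTER_PORT come from _setup_env above; RANK and
    # WORLD_SIZE are presumably set by the truncated remainder of it.
    dist.init_process_group(
        backend="nccl",  # assumption: NCCL for multi-GPU inference
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
    )
```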
@@ -36,10 +36,11 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
     # create programmatic elastic launch
     print("Launching distributed inference ...")
 
-    num_processes_per_node = 4  # builder_args.num_gpus + 1
+    num_processes_per_node = builder_args.pp * builder_args.tp
 
     from torchchat.distributed.dist_run import main
-    mp.set_start_method('spawn')
+
+    mp.set_start_method("spawn")
 
     pipes = []
     procs = []
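Note: the hardcoded process count is replaced by the product of the pipeline-parallel (`pp`) and tensor-parallel (`tp`) degrees from `BuilderArgs`, i.e. the world size of the 2D parallelism mesh. A quick arithmetic illustration (values are illustrative, not from the diff):

```python
# Illustrative only: pp pipeline stages, each sharded across tp GPUs.
pp, tp = 2, 4
num_processes_per_node = pp * tp  # 8 ranks, one per GPU
```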
@@ -48,25 +49,24 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
         pipes.append(server_pipe)
         proc = mp.Process(
             target=partial(_setup_env, num_processes_per_node, rank, main),
-            args=(builder_args, client_pipe)
+            args=(builder_args, client_pipe),
         )
         proc.start()
 
-
     for pipe in pipes:
         response = pipe.recv()
 
-    print(
-        f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs."
-    )
+    print(f"Done launching distributed inference on {num_processes_per_node} GPUs.")
     return procs, pipes
 
+
 @dataclass
 class Output:
     is_finished: bool = False
     text: Optional[str] = None
     token: Optional[list] = None
 
+
 @dataclass
 class Request:
     request_id: int
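Note: each rank gets its own duplex pipe: the worker receives `client_pipe` as an argument, while the parent keeps `server_pipe` and blocks on `recv()` as a readiness barrier before declaring the launch done. The `mp.Pipe()` call itself sits just above the visible context, so this pairing is a sketch of what it implies, not the committed line:

```python
# Sketch of the per-rank pipe setup implied by the hunk above; the exact
# placement of mp.Pipe() is outside the visible context and is assumed.
for rank in range(num_processes_per_node):
    server_pipe, client_pipe = mp.Pipe(duplex=True)
    pipes.append(server_pipe)  # parent keeps this end for step/stop commands
    # client_pipe is handed to the spawned worker via args=(...)
```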
@@ -84,7 +84,7 @@ def __init__(
         generator_args,
         pipes,
         loop,
-        ):
+    ):
         self.builder_args = builder_args
         self.generator_args = generator_args
         self.requests = {}
@@ -107,7 +107,7 @@ def process_requests_loop(self):
             if req == "stop":
                 break
             self.requests = {req.request_id: req.prompt}
-
+
             responses = {}
             running = True
             while running:
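Note: the scheduler thread blocks until either a `Request` or the `"stop"` sentinel arrives, then drives `step()` until the batch finishes. A sketch of the feeding side, assuming a queue named `request_queue` connects `DistributedGenerator.generate` to this loop (the queue is created outside the visible hunks, so the name is an assumption):

```python
from uuid import uuid4

# Producer-side sketch; `request_queue` and the Request fields are
# assumptions consistent with the loop above (req.request_id, req.prompt).
req = Request(request_id=uuid4().int, prompt="Hello")
scheduler.request_queue.put(req)
# ... later, to let the scheduler thread exit cleanly:
scheduler.request_queue.put("stop")
```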
@@ -128,17 +128,17 @@ async def wait_for_request(self, req: Request) -> Output:
             yield output
         del self.req_to_states[req.request_id]
         del self.req_to_results[req.request_id]
-
+
     def step(self) -> List[Output]:
         responses = []
-        #TODO: Implement a scheduler to handle the requests
+        # TODO: Implement a scheduler to handle the requests
         if len(self.in_flight_requests) > 0:
-            #Receive decoded token
+            # Receive decoded token
             for p in self.pipes:
                 p.send("step")
             for p in self.pipes:
                 responses.append(p.recv())
-
+
         else:
             # Send requests to backend
             self.in_flight_batch_order = list(self.requests.keys())
@@ -148,25 +148,26 @@ def step(self) -> List[Output]:
             self.in_flight_requests = self.requests
             self.requests = {}
             self.current_step = 0
-            #Receive first token
+            # Receive first token
             for p in self.pipes:
                 responses.append(p.recv())
-        responses = responses[0]
+        # Filter out None responses from in-between stages
+        responses = [r for r in responses if r is not None][0]
         outputs = []
         for k, v in zip(self.in_flight_batch_order, zip(responses[0], responses[1])):
             text, token_ids = v
             outputs.append(
                 Output(
-                    is_finished=self.current_step >= self.generator_args.max_new_tokens,
+                    is_finished=self.current_step >= self.generator_args.max_new_tokens,
                     text=text,
                     token=token_ids,
-                    )
                 )
+            )
         if self.current_step >= self.generator_args.max_new_tokens:
             for p in self.pipes:
                 p.send("stop")
             self.in_flight_requests = []
-
+
         self.current_step += 1
 
         return outputs
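Note: `step()` drives all worker pipes in lockstep: `"step"` requests one more decoded token, `"stop"` ends the batch, and only non-`None` replies are kept because intermediate pipeline stages have no decoded text to report. The worker side lives in `dist_run.main` and is not part of this diff; a sketch of the loop it would need to satisfy this protocol (all names here are assumptions derived from the parent side):

```python
def worker_loop(pipe, is_last_stage):
    # Counterpart to Scheduler.step(); inferred from the parent side only.
    while True:
        cmd = pipe.recv()
        if cmd == "stop":
            break
        # cmd == "step": advance decoding by one token for the batch.
        texts, token_ids = decode_one_token()  # hypothetical helper
        # Only the last pipeline stage holds decoded output; the other
        # stages answer None, which Scheduler.step() filters out.
        pipe.send((texts, token_ids) if is_last_stage else None)
```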
@@ -177,24 +178,28 @@ def __init__(
         self,
         builder_args: BuilderArgs,
         tokenizer_args: TokenizerArgs,
-        #TODO: move GeneratorArgs into a different module
+        # TODO: move GeneratorArgs into a different module
         generator_args,
         profile: Optional[Path],
         quantize: bool,
         draft_quantize: bool,
-        ):
+    ):
         self.builder_args = builder_args
         self.generate_args = generator_args
-
+
+        self.check_args()
+
         self.procs, self.pipes = _launch_distributed_inference(builder_args)
 
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
 
         self.scheduler = Scheduler(builder_args, generator_args, self.pipes, self.loop)
 
-        #TODO: Mode into process and use pipe or queue for comm
-        self.scheduler_thread = threading.Thread(target=self.scheduler.process_requests_loop)
+        # TODO: Move into a process and use a pipe or queue for comm
+        self.scheduler_thread = threading.Thread(
+            target=self.scheduler.process_requests_loop
+        )
         self.scheduler_thread.start()
 
         atexit.register(self.shutdown)
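Note: the scheduler runs on a plain thread with its own asyncio event loop, and cleanup is hooked to interpreter exit via `atexit`. `shutdown()` itself is outside this diff; a plausible sketch given the `"stop"` sentinel and the worker protocol (an assumption, not the committed code):

```python
def shutdown(self):
    # Unblock the scheduler thread, then reap the worker processes.
    self.scheduler.request_queue.put("stop")  # assumed queue name
    self.scheduler_thread.join()
    for proc in self.procs:
        proc.terminate()
```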
@@ -220,3 +225,9 @@ def generate(self, text):
             running &= not output.is_finished
 
             yield output
+
+    def check_args(self):
+        if self.generate_args.chat_mode:
+            raise NotImplementedError(
+                "Currently we only support generate with --distributed"
+            )
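Note: with `check_args()` rejecting chat mode, plain generation is the only supported path under `--distributed`. An end-to-end usage sketch; the builder/tokenizer/generator args are placeholders for real configs, not values from the diff:

```python
gen = DistributedGenerator(
    builder_args,    # BuilderArgs with pp/tp set
    tokenizer_args,  # TokenizerArgs
    generator_args,  # must have chat_mode=False
    profile=None,
    quantize=False,
    draft_quantize=False,
)
for output in gen.generate("Write a haiku about distributed inference"):
    if output.text is not None:
        print(output.text, end="", flush=True)
```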