# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import abstractmethod
- from typing import List, Optional
+ from collections import deque
from dataclasses import dataclass
- from pathlib import Path
+ from functools import partial
from os import environ
+ from pathlib import Path
from torchchat.cli.builder import BuilderArgs, TokenizerArgs
- from functools import partial
+ from typing import List, Optional
+ from uuid import uuid4

+ import asyncio
import atexit
import torch.multiprocessing as mp
+ import threading
import importlib.util
import subprocess

@@ -51,7 +55,6 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:

    for pipe in pipes:
        response = pipe.recv()
-         print(f"Received: {response=}")

    print(
        f"Done launching distributed inference on {builder_args.num_gpus} GPUs."
@@ -60,56 +63,72 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:

@dataclass
class Output:
-     request_id: int
    is_finished: bool = False
-     output: Optional[str] = None
-
- class Generator(object):
+     text: Optional[str] = None
+     token: Optional[list] = None

-     @abstractmethod
-     def add_request(self, request_id: int, prompt: str):
-         raise NotImplementedError()
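+ # A single generation request; uuid4().int supplies a unique integer id.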
+ @dataclass
+ class Request:
+     request_id: int
+     prompt: str

+     @classmethod
+     def new_request(cls, prompt):
+         return cls(request_id=uuid4().int, prompt=prompt)


- class DistributedGenerator(Generator):
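+ # Runs the request loop on a worker thread: pulls requests off an mp.Queue,
+ # drives step(), and signals results back to the asyncio loop per request.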
+ class Scheduler(object):
    def __init__(
        self,
-         builder_args: BuilderArgs,
-         speculative_builder_args: BuilderArgs,
-         tokenizer_args: TokenizerArgs,
-         #TODO: move GeneratorArgs into a different module
-         # generator_args: GeneratorArgs,
-         profile: Optional[Path],
-         quantize: bool,
-         draft_quantize: bool,
+         builder_args,
+         generator_args,
+         pipes,
+         loop,
    ):
        self.builder_args = builder_args
+         self.generator_args = generator_args
        self.requests = {}
        self.in_flight_requests = {}
-         # For now we have a static batch order we save separately
        self.in_flight_batch_order = []
-         # if builder_args.distributed:
-         #     # we part ways here with torchchat cli and move into dist inference
-         self.procs, self.pipes = _launch_distributed_inference(builder_args)
-         self.current_step = 0
-
-         atexit.register(self.shutdown)
-
-     def shutdown(self):
-         for p in self.pipes:
-             p.send("stop")
-         for p in self.procs:
-             p.kill()
-
-     #TODO: Replace against (async) generate
-     def add_request(self, request_id: int, prompt: str):
-         assert request_id not in self.requests
-         self.requests[request_id] = prompt
-
-
+         self.pipes = pipes
+         self.req_to_states = {}
+         self.req_to_results = {}
+         self.request_queue = mp.Queue()
+         self.loop = loop
+
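+     # Register wakeup state for a request and hand it to the scheduler thread.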
+     def schedule_request(self, req: Request):
+         self.req_to_states[req.request_id] = asyncio.Event()
+         self.req_to_results[req.request_id] = deque()
+         self.request_queue.put(req)
+
+     def process_requests_loop(self):
+         while True:
+             req = self.request_queue.get()
+             if req == "stop":
+                 break
+             self.requests = {req.request_id: req.prompt}
+             # Reset the step counter for each request; step() compares it
+             # against max_new_tokens to decide when generation is finished.
+             self.current_step = 0
+
+             running = True
+             while running:
+                 outputs = self.step()
+                 self.req_to_results[req.request_id].append(outputs[0])
+
+                 # Wake the asyncio side from this worker thread.
+                 self.loop.call_soon_threadsafe(self.req_to_states[req.request_id].set)
+
+                 running &= not outputs[0].is_finished
+
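+     # Async side: await the event set by the scheduler thread and drain the
+     # result deque, yielding Outputs until one is marked finished.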
+     async def wait_for_request(self, req: Request) -> Output:
+         is_finished = False
+         while not is_finished:
+             await self.req_to_states[req.request_id].wait()
+             while len(self.req_to_results[req.request_id]):
+                 output = self.req_to_results[req.request_id].popleft()
+                 is_finished |= output.is_finished
+                 yield output
+         del self.req_to_states[req.request_id]
+         del self.req_to_results[req.request_id]
+
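    # One decode step: gathers the latest (text, token) pair from the worker
    # pipes for each in-flight request.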
    def step(self) -> List[Output]:
        responses = []
        #TODO: Implement a scheduler to handle the requests
@@ -132,12 +151,72 @@ def step(self) -> List[Output]:
        # Receive first token
        for p in self.pipes:
            responses.append(p.recv())
-
        responses = responses[0]
        outputs = []
-         for k, v in zip(self.in_flight_batch_order, responses):
-             outputs.append(Output(k, is_finished=self.current_step >= self.builder_args.ntokens, output=v))
+         for k, v in zip(self.in_flight_batch_order, zip(responses[0], responses[1])):
+             text, token_ids = v
+             outputs.append(
+                 Output(
+                     is_finished=self.current_step >= self.generator_args.max_new_tokens,
+                     text=text,
+                     token=token_ids,
+                 )
+             )
+         if self.current_step >= self.generator_args.max_new_tokens:
+             for p in self.pipes:
+                 p.send("stop")
+             self.in_flight_requests = []

        self.current_step += 1

        return outputs
+
+
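+ # Public entry point: owns the worker processes, the event loop, and the
+ # scheduler thread.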
+ class DistributedGenerator(object):
+     def __init__(
+         self,
+         builder_args: BuilderArgs,
+         tokenizer_args: TokenizerArgs,
+         # TODO: move GeneratorArgs into a different module
+         generator_args,
+         profile: Optional[Path],
+         quantize: bool,
+         draft_quantize: bool,
+     ):
+         self.builder_args = builder_args
+         self.generate_args = generator_args
+
+         self.procs, self.pipes = _launch_distributed_inference(builder_args)
+
+         self.loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(self.loop)
+
+         self.scheduler = Scheduler(builder_args, generator_args, self.pipes, self.loop)
+
+         # TODO: Move into a process and use a pipe or queue for comm
+         self.scheduler_thread = threading.Thread(target=self.scheduler.process_requests_loop)
+         self.scheduler_thread.start()
+
+         atexit.register(self.shutdown)
+
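+     # Stop the scheduler thread first, then the distributed worker processes.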
+     def shutdown(self):
+         self.scheduler.request_queue.put("stop")
+         self.scheduler_thread.join()
+
+         for p in self.pipes:
+             p.send("stop")
+         for p in self.procs:
+             p.kill()
+
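+     # Synchronous wrapper: schedules the request, then drains the async
+     # generator one Output at a time on the local event loop.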
+     def generate(self, text):
+         req = Request.new_request(text)
+         self.scheduler.schedule_request(req)
+
+         generator = self.scheduler.wait_for_request(req)
+
+         running = True
+         while running:
+             output = self.loop.run_until_complete(generator.__anext__())
+             running &= not output.is_finished
+
+             yield output
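
# Usage sketch (illustrative, not part of this commit): assumes builder_args,
# tokenizer_args, and a GeneratorArgs-like object with max_new_tokens are
# already constructed, e.g. via the torchchat CLI argument parsing.
#
#     gen = DistributedGenerator(
#         builder_args, tokenizer_args, generator_args,
#         profile=None, quantize=False, draft_quantize=False,
#     )
#     for output in gen.generate("Hello, world"):
#         print(output.text, end="", flush=True)
#     gen.shutdown()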