@@ -123,10 +123,34 @@ async def resolve(
123
123
...
124
124
125
125
async def get_request(
    self,
    requests_queue: multiprocessing.Queue,
    shutdown_event: MultiprocessingEvent,
    process_id: int,
    shutdown_poll_interval_seconds: float,
) -> Optional[WorkerProcessRequest[RequestT]]:
    """
    Wait for the next item on ``requests_queue`` while staying responsive
    to shutdown.

    A plain ``asyncio.to_thread(requests_queue.get)`` parks a worker thread
    in a blocking ``get()`` that task cancellation cannot interrupt: the
    awaiting coroutine is cancelled, but the thread keeps waiting on the
    queue. To avoid that, the queue is polled with a timeout and
    ``shutdown_event`` is checked before every poll, so the helper thread
    exits within roughly one poll interval of the event being set.

    :param requests_queue: Queue to read the next request from.
    :param shutdown_event: Event-like object; only ``is_set()`` is used.
    :param process_id: Identifier of the calling worker process, used in
        the cancellation message.
    :param shutdown_poll_interval_seconds: Timeout, in seconds, of each
        blocking ``get()`` attempt, i.e. how often the event is re-checked.
    :return: The item taken from the queue, returned unchanged.
    :raises asyncio.CancelledError: If ``shutdown_event`` is set before an
        item could be retrieved.
    """
    # Function-scope import: Queue.get(timeout=...) signals an expired
    # timeout by raising queue.Empty.
    from queue import Empty

    # Private sentinel so that any real queue item (including None) can be
    # told apart from "stopped because of shutdown".
    shutdown_requested = object()

    def _poll_queue():
        # Runs in a worker thread. The event is checked before each get()
        # so no new item is dequeued once shutdown has been requested.
        while not shutdown_event.is_set():
            try:
                return requests_queue.get(timeout=shutdown_poll_interval_seconds)
            except Empty:
                continue
        return shutdown_requested

    polled = await asyncio.to_thread(_poll_queue)  # type: ignore[attr-defined]
    if polled is shutdown_requested:
        # Raised in the coroutine (not in the thread) so the caller's loop
        # unwinds exactly as it would on a regular task cancellation.
        raise asyncio.CancelledError(
            f"Shutdown event set, cancelling request wait in process {process_id}."
        )
    return polled
130
154
131
155
async def send_result (
132
156
self ,
@@ -203,11 +227,15 @@ def run_process(
203
227
max_concurrency : Optional [int ] = None ,
204
228
):
205
229
async def _process_runner ():
230
+ import threading
231
+ internal_shutdown_event = threading .Event ()
206
232
if type_ == "sync" :
207
233
loop_task = asyncio .create_task (self ._process_synchronous_requests_loop (
208
234
requests_queue = requests_queue ,
209
235
results_queue = results_queue ,
210
236
process_id = process_id ,
237
+ shutdown_event = internal_shutdown_event ,
238
+ shutdown_poll_interval_seconds = shutdown_poll_interval_seconds ,
211
239
), name = "request_loop_processor_task" )
212
240
elif type_ == "async" :
213
241
if max_concurrency is None :
@@ -218,6 +246,8 @@ async def _process_runner():
218
246
results_queue = results_queue ,
219
247
max_concurrency = max_concurrency ,
220
248
process_id = process_id ,
249
+ shutdown_event = internal_shutdown_event ,
250
+ shutdown_poll_interval_seconds = shutdown_poll_interval_seconds ,
221
251
), name = "request_loop_processor_task" )
222
252
else :
223
253
raise ValueError (f"Invalid process type: { type_ } " )
@@ -237,10 +267,12 @@ async def _process_runner():
237
267
],
238
268
return_when = asyncio .FIRST_EXCEPTION ,
239
269
)
270
+ logger .info ("First exception happened" )
240
271
241
272
for task in pending :
242
273
logger .debug (f"Cancelling task { task .get_name ()} " )
243
274
cancel_result = task .cancel ()
275
+ internal_shutdown_event .set ()
244
276
logger .debug (f"{ 'Task is already done or canceled' if not cancel_result else 'sent cancel signal' } " )
245
277
try :
246
278
await task
@@ -271,18 +303,22 @@ async def _wait_for_shutdown(
271
303
await asyncio .sleep (shutdown_poll_interval )
272
304
273
305
logger .debug ("Shutdown signal received" )
274
- raise ValueError ("kaki" )
275
306
raise asyncio .CancelledError ("Shutdown event set, cancelling process loop." )
276
307
277
308
async def _process_synchronous_requests_loop (
278
309
self ,
279
310
requests_queue : multiprocessing .Queue ,
280
311
results_queue : multiprocessing .Queue ,
281
312
process_id : int ,
313
+ shutdown_event : MultiprocessingEvent ,
314
+ shutdown_poll_interval_seconds : float ,
282
315
):
283
316
while True :
284
317
process_request = await self .get_request (
285
318
requests_queue = requests_queue ,
319
+ shutdown_event = shutdown_event ,
320
+ process_id = process_id ,
321
+ shutdown_poll_interval_seconds = shutdown_poll_interval_seconds
286
322
)
287
323
288
324
dequeued_time = time .time ()
@@ -297,15 +333,14 @@ async def _process_synchronous_requests_loop(
297
333
process_id = process_id ,
298
334
)
299
335
300
- logger .debug ("Done processing synchronous loop" )
301
-
302
-
303
336
async def _process_asynchronous_requests_loop (
304
337
self ,
305
338
requests_queue : multiprocessing .Queue ,
306
339
results_queue : multiprocessing .Queue ,
307
340
max_concurrency : int ,
308
341
process_id : int ,
342
+ shutdown_event : MultiprocessingEvent ,
343
+ shutdown_poll_interval_seconds : float ,
309
344
):
310
345
pending = asyncio .Semaphore (max_concurrency )
311
346
@@ -316,6 +351,9 @@ async def _process_asynchronous_requests_loop(
316
351
logger .info ("Awaiting request..." )
317
352
process_request = await self .get_request (
318
353
requests_queue = requests_queue ,
354
+ shutdown_event = shutdown_event ,
355
+ process_id = process_id ,
356
+ shutdown_poll_interval_seconds = shutdown_poll_interval_seconds ,
319
357
)
320
358
321
359
dequeued_time = time .time ()
@@ -351,8 +389,6 @@ def _task_done(_: asyncio.Task):
351
389
task .add_done_callback (_task_done )
352
390
await asyncio .sleep (0 ) # enable start task immediately
353
391
354
- logger .debug ("Done processing asynchronous loop" )
355
-
356
392
357
393
class GenerativeRequestsWorkerDescription (WorkerDescription ):
358
394
type_ : Literal ["generative_requests_worker" ] = "generative_requests_worker" # type: ignore[assignment]
0 commit comments