9
9
from collections .abc import AsyncGenerator
10
10
from dataclasses import dataclass
11
11
from multiprocessing .synchronize import Event as MultiprocessingEvent
12
- from threading import Event
13
12
from typing import (
14
13
Any ,
15
14
Generic ,
@@ -131,25 +130,30 @@ async def resolve(
131
130
async def get_request (
132
131
self ,
133
132
requests_queue : multiprocessing .Queue ,
134
- shutdown_event : threading .Event ,
135
133
process_id : int ,
136
134
shutdown_poll_interval_seconds : float ,
137
135
) -> WorkerProcessRequest [RequestT ]:
136
+ shutdown_event = threading .Event ()
137
+
138
138
# We need to check shutdown_event intermittently cause
139
139
# if we simply use asyncio.to_thread(requests_queue.get)
140
140
# the cancellation task doesn't propagate because the
141
141
# asyncio.to_thread is blocking
142
142
def _get_queue_intermittently ():
143
- while True :
143
+ while not shutdown_event . is_set () :
144
144
try :
145
- return requests_queue .get (timeout = shutdown_poll_interval_seconds )
146
- except queue .Empty as e :
147
- logger .info ("Checking shutdown even is set in get_request" )
148
- if shutdown_event .is_set ():
149
- logger .info (f"Shutdown signal received in future { process_id } " )
150
- raise asyncio .CancelledError from e
145
+ request = requests_queue .get (timeout = shutdown_poll_interval_seconds )
146
+ logger .debug (f"Gor request in future { process_id } " )
147
+ return request
148
+ except queue .Empty :
149
+ logger .trace (f"Queue was empty in future { process_id } " )
150
+ logger .info (f"Shutdown signal received in future { process_id } " )
151
+ return None
151
152
152
- return await asyncio .to_thread (_get_queue_intermittently ) # type: ignore[attr-defined]
153
+ try :
154
+ return await asyncio .to_thread (_get_queue_intermittently )
155
+ finally :
156
+ shutdown_event .set ()
153
157
154
158
async def send_result (
155
159
self ,
@@ -175,15 +179,17 @@ async def resolve_scheduler_request(
175
179
scheduled_time = time .time (),
176
180
process_id = process_id ,
177
181
)
178
- request_scheduled_result : WorkerProcessResult [RequestT , ResponseT ] = (
179
- WorkerProcessResult (
180
- type_ = "request_scheduled" ,
181
- request = request ,
182
- response = None ,
183
- info = info ,
182
+ asyncio .create_task (
183
+ self .send_result (
184
+ results_queue ,
185
+ WorkerProcessResult (
186
+ type_ = "request_scheduled" ,
187
+ request = request ,
188
+ response = None ,
189
+ info = info ,
190
+ ),
184
191
)
185
192
)
186
- asyncio .create_task (self .send_result (results_queue , request_scheduled_result ))
187
193
188
194
if (wait_time := start_time - time .time ()) > 0 :
189
195
await asyncio .sleep (wait_time )
@@ -223,37 +229,26 @@ def run_process(
223
229
shutdown_event : MultiprocessingEvent ,
224
230
shutdown_poll_interval_seconds : float ,
225
231
process_id : int ,
226
- max_concurrency : Optional [ int ] = None ,
232
+ max_concurrency : int ,
227
233
):
228
234
async def _process_runner ():
229
- # We are using a separate internal event
230
- # because if we're using the shutdown_event
231
- # there's a race condition between the get_request
232
- # loop which checks for shutdown and the .cancel() in this
233
- # method which causes the asyncio.CancelledError
234
- # to propagate and crash the worker
235
- internal_shutdown_event : threading .Event = Event ()
236
235
if type_ == "sync" :
237
236
loop_task = asyncio .create_task (
238
237
self ._process_synchronous_requests_loop (
239
238
requests_queue = requests_queue ,
240
239
results_queue = results_queue ,
241
240
process_id = process_id ,
242
- shutdown_event = internal_shutdown_event ,
243
241
shutdown_poll_interval_seconds = shutdown_poll_interval_seconds ,
244
242
),
245
243
name = "request_loop_processor_task" ,
246
244
)
247
245
elif type_ == "async" :
248
- if max_concurrency is None :
249
- raise ValueError ("max_concurrency must be set for async processor" )
250
246
loop_task = asyncio .create_task (
251
247
self ._process_asynchronous_requests_loop (
252
248
requests_queue = requests_queue ,
253
249
results_queue = results_queue ,
254
250
max_concurrency = max_concurrency ,
255
251
process_id = process_id ,
256
- shutdown_event = internal_shutdown_event ,
257
252
shutdown_poll_interval_seconds = shutdown_poll_interval_seconds ,
258
253
),
259
254
name = "request_loop_processor_task" ,
@@ -286,7 +281,6 @@ async def _process_runner():
286
281
f"Cancelling task { task .get_name ()} || Process { process_id } "
287
282
)
288
283
task .cancel ()
289
- internal_shutdown_event .set ()
290
284
try : # noqa: SIM105
291
285
await task
292
286
except asyncio .CancelledError :
@@ -317,9 +311,6 @@ async def _wait_for_shutdown(
317
311
while not shutdown_event .is_set (): # noqa: ASYNC110
318
312
await asyncio .sleep (shutdown_poll_interval )
319
313
320
- # Raising asyncio.CancelledError instead would
321
- # cause the asyncio.wait above to wait
322
- # forever, couldn't find a reasonable reason why
323
314
raise ShutdownSignalReceivedError (
324
315
f"Shutdown event set for process { process_id } , cancelling process loop."
325
316
)
@@ -329,13 +320,11 @@ async def _process_synchronous_requests_loop(
329
320
requests_queue : multiprocessing .Queue ,
330
321
results_queue : multiprocessing .Queue ,
331
322
process_id : int ,
332
- shutdown_event : threading .Event ,
333
323
shutdown_poll_interval_seconds : float ,
334
324
):
335
325
while True :
336
326
process_request = await self .get_request (
337
327
requests_queue = requests_queue ,
338
- shutdown_event = shutdown_event ,
339
328
process_id = process_id ,
340
329
shutdown_poll_interval_seconds = shutdown_poll_interval_seconds ,
341
330
)
@@ -358,7 +347,6 @@ async def _process_asynchronous_requests_loop(
358
347
results_queue : multiprocessing .Queue ,
359
348
max_concurrency : int ,
360
349
process_id : int ,
361
- shutdown_event : threading .Event ,
362
350
shutdown_poll_interval_seconds : float ,
363
351
):
364
352
pending = asyncio .Semaphore (max_concurrency )
@@ -369,7 +357,6 @@ async def _process_asynchronous_requests_loop(
369
357
while True :
370
358
process_request = await self .get_request (
371
359
requests_queue = requests_queue ,
372
- shutdown_event = shutdown_event ,
373
360
process_id = process_id ,
374
361
shutdown_poll_interval_seconds = shutdown_poll_interval_seconds ,
375
362
)
@@ -461,7 +448,7 @@ def run_process(
461
448
shutdown_event : MultiprocessingEvent ,
462
449
shutdown_poll_interval_seconds : float ,
463
450
process_id : int ,
464
- max_concurrency : Optional [ int ] = None ,
451
+ max_concurrency : int ,
465
452
):
466
453
asyncio .run (self .backend .validate ())
467
454
super ().run_process (
0 commit comments