Skip to content

Commit 47e7155

Browse files
fix: final PR comments
1 parent e2b3691 commit 47e7155

File tree

1 file changed

+25
-38
lines changed

1 file changed

+25
-38
lines changed

src/guidellm/scheduler/worker.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from collections.abc import AsyncGenerator
1010
from dataclasses import dataclass
1111
from multiprocessing.synchronize import Event as MultiprocessingEvent
12-
from threading import Event
1312
from typing import (
1413
Any,
1514
Generic,
@@ -131,25 +130,30 @@ async def resolve(
131130
async def get_request(
132131
self,
133132
requests_queue: multiprocessing.Queue,
134-
shutdown_event: threading.Event,
135133
process_id: int,
136134
shutdown_poll_interval_seconds: float,
137135
) -> WorkerProcessRequest[RequestT]:
136+
shutdown_event = threading.Event()
137+
138138
# We need to check shutdown_event intermittently because
139139
# if we simply use asyncio.to_thread(requests_queue.get)
140140
# the cancellation task doesn't propagate because the
141141
# asyncio.to_thread is blocking
142142
def _get_queue_intermittently():
143-
while True:
143+
while not shutdown_event.is_set():
144144
try:
145-
return requests_queue.get(timeout=shutdown_poll_interval_seconds)
146-
except queue.Empty as e:
147-
logger.info("Checking shutdown even is set in get_request")
148-
if shutdown_event.is_set():
149-
logger.info(f"Shutdown signal received in future {process_id}")
150-
raise asyncio.CancelledError from e
145+
request = requests_queue.get(timeout=shutdown_poll_interval_seconds)
146+
logger.debug(f"Gor request in future {process_id}")
147+
return request
148+
except queue.Empty:
149+
logger.trace(f"Queue was empty in future {process_id}")
150+
logger.info(f"Shutdown signal received in future {process_id}")
151+
return None
151152

152-
return await asyncio.to_thread(_get_queue_intermittently) # type: ignore[attr-defined]
153+
try:
154+
return await asyncio.to_thread(_get_queue_intermittently)
155+
finally:
156+
shutdown_event.set()
153157

154158
async def send_result(
155159
self,
@@ -175,15 +179,17 @@ async def resolve_scheduler_request(
175179
scheduled_time=time.time(),
176180
process_id=process_id,
177181
)
178-
request_scheduled_result: WorkerProcessResult[RequestT, ResponseT] = (
179-
WorkerProcessResult(
180-
type_="request_scheduled",
181-
request=request,
182-
response=None,
183-
info=info,
182+
asyncio.create_task(
183+
self.send_result(
184+
results_queue,
185+
WorkerProcessResult(
186+
type_="request_scheduled",
187+
request=request,
188+
response=None,
189+
info=info,
190+
),
184191
)
185192
)
186-
asyncio.create_task(self.send_result(results_queue, request_scheduled_result))
187193

188194
if (wait_time := start_time - time.time()) > 0:
189195
await asyncio.sleep(wait_time)
@@ -223,37 +229,26 @@ def run_process(
223229
shutdown_event: MultiprocessingEvent,
224230
shutdown_poll_interval_seconds: float,
225231
process_id: int,
226-
max_concurrency: Optional[int] = None,
232+
max_concurrency: int,
227233
):
228234
async def _process_runner():
229-
# We are using a separate internal event
230-
# because if we're using the shutdown_event
231-
# there's a race condition between the get_request
232-
# loop which checks for shutdown and the .cancel() in this
233-
# method which causes the asyncio.CancelledError
234-
# to propagate and crash the worker
235-
internal_shutdown_event: threading.Event = Event()
236235
if type_ == "sync":
237236
loop_task = asyncio.create_task(
238237
self._process_synchronous_requests_loop(
239238
requests_queue=requests_queue,
240239
results_queue=results_queue,
241240
process_id=process_id,
242-
shutdown_event=internal_shutdown_event,
243241
shutdown_poll_interval_seconds=shutdown_poll_interval_seconds,
244242
),
245243
name="request_loop_processor_task",
246244
)
247245
elif type_ == "async":
248-
if max_concurrency is None:
249-
raise ValueError("max_concurrency must be set for async processor")
250246
loop_task = asyncio.create_task(
251247
self._process_asynchronous_requests_loop(
252248
requests_queue=requests_queue,
253249
results_queue=results_queue,
254250
max_concurrency=max_concurrency,
255251
process_id=process_id,
256-
shutdown_event=internal_shutdown_event,
257252
shutdown_poll_interval_seconds=shutdown_poll_interval_seconds,
258253
),
259254
name="request_loop_processor_task",
@@ -286,7 +281,6 @@ async def _process_runner():
286281
f"Cancelling task {task.get_name()}|| Process {process_id}"
287282
)
288283
task.cancel()
289-
internal_shutdown_event.set()
290284
try: # noqa: SIM105
291285
await task
292286
except asyncio.CancelledError:
@@ -317,9 +311,6 @@ async def _wait_for_shutdown(
317311
while not shutdown_event.is_set(): # noqa: ASYNC110
318312
await asyncio.sleep(shutdown_poll_interval)
319313

320-
# Raising asyncio.CancelledError instead would
321-
# cause the asyncio.wait above to wait
322-
# forever, couldn't find a reasonable reason why
323314
raise ShutdownSignalReceivedError(
324315
f"Shutdown event set for process {process_id}, cancelling process loop."
325316
)
@@ -329,13 +320,11 @@ async def _process_synchronous_requests_loop(
329320
requests_queue: multiprocessing.Queue,
330321
results_queue: multiprocessing.Queue,
331322
process_id: int,
332-
shutdown_event: threading.Event,
333323
shutdown_poll_interval_seconds: float,
334324
):
335325
while True:
336326
process_request = await self.get_request(
337327
requests_queue=requests_queue,
338-
shutdown_event=shutdown_event,
339328
process_id=process_id,
340329
shutdown_poll_interval_seconds=shutdown_poll_interval_seconds,
341330
)
@@ -358,7 +347,6 @@ async def _process_asynchronous_requests_loop(
358347
results_queue: multiprocessing.Queue,
359348
max_concurrency: int,
360349
process_id: int,
361-
shutdown_event: threading.Event,
362350
shutdown_poll_interval_seconds: float,
363351
):
364352
pending = asyncio.Semaphore(max_concurrency)
@@ -369,7 +357,6 @@ async def _process_asynchronous_requests_loop(
369357
while True:
370358
process_request = await self.get_request(
371359
requests_queue=requests_queue,
372-
shutdown_event=shutdown_event,
373360
process_id=process_id,
374361
shutdown_poll_interval_seconds=shutdown_poll_interval_seconds,
375362
)
@@ -461,7 +448,7 @@ def run_process(
461448
shutdown_event: MultiprocessingEvent,
462449
shutdown_poll_interval_seconds: float,
463450
process_id: int,
464-
max_concurrency: Optional[int] = None,
451+
max_concurrency: int,
465452
):
466453
asyncio.run(self.backend.validate())
467454
super().run_process(

0 commit comments

Comments
 (0)