Skip to content

Commit 35abac7

Browse files
WIP - Stuck after shutdown signal received
1 parent 883593a commit 35abac7

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

src/guidellm/scheduler/worker.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,34 @@ async def resolve(
123123
...
124124

125125
async def get_request(
126-
self,
127-
requests_queue: multiprocessing.Queue,
126+
self, requests_queue: multiprocessing.Queue,
127+
shutdown_event: MultiprocessingEvent,
128+
process_id: int,
129+
shutdown_poll_interval_seconds: float,
128130
) -> Optional[WorkerProcessRequest[RequestT]]:
129-
return await asyncio.to_thread(requests_queue.get) # type: ignore[attr-defined]
131+
# We need to check shutdown_event intermittently because
132+
# if we simply use asyncio.to_thread(requests_queue.get)
133+
# the task cancellation doesn't propagate because the
134+
# asyncio.to_thread is blocking
135+
return await asyncio.to_thread(requests_queue.get)
136+
# def _get_queue_intermittently():
137+
# while True:
138+
# try:
139+
# return requests_queue.get(timeout=shutdown_poll_interval_seconds)
140+
# except queue.Empty:
141+
# logger.info("Checking shutdown event is set in get_request")
142+
# if shutdown_event.is_set():
143+
# logger.info(f"Shutdown signal received in future {process_id}")
144+
# raise asyncio.CancelledError()
145+
# # return None
146+
#
147+
# try:
148+
# return await asyncio.to_thread(_get_queue_intermittently) # type: ignore[attr-defined]
149+
# except asyncio.CancelledError:
150+
# logger.info("kaki")
151+
# # return None
152+
# raise
153+
# # raise
130154

131155
async def send_result(
132156
self,
@@ -203,11 +227,15 @@ def run_process(
203227
max_concurrency: Optional[int] = None,
204228
):
205229
async def _process_runner():
230+
import threading
231+
internal_shutdown_event = threading.Event()
206232
if type_ == "sync":
207233
loop_task = asyncio.create_task(self._process_synchronous_requests_loop(
208234
requests_queue=requests_queue,
209235
results_queue=results_queue,
210236
process_id=process_id,
237+
shutdown_event=internal_shutdown_event,
238+
shutdown_poll_interval_seconds=shutdown_poll_interval_seconds,
211239
), name="request_loop_processor_task")
212240
elif type_ == "async":
213241
if max_concurrency is None:
@@ -218,6 +246,8 @@ async def _process_runner():
218246
results_queue=results_queue,
219247
max_concurrency=max_concurrency,
220248
process_id=process_id,
249+
shutdown_event=internal_shutdown_event,
250+
shutdown_poll_interval_seconds=shutdown_poll_interval_seconds,
221251
), name="request_loop_processor_task")
222252
else:
223253
raise ValueError(f"Invalid process type: {type_}")
@@ -237,10 +267,12 @@ async def _process_runner():
237267
],
238268
return_when=asyncio.FIRST_EXCEPTION,
239269
)
270+
logger.info("First exception happened")
240271

241272
for task in pending:
242273
logger.debug(f"Cancelling task {task.get_name()}")
243274
cancel_result = task.cancel()
275+
internal_shutdown_event.set()
244276
logger.debug(f"{'Task is already done or canceled' if not cancel_result else 'sent cancel signal'}")
245277
try:
246278
await task
@@ -271,18 +303,22 @@ async def _wait_for_shutdown(
271303
await asyncio.sleep(shutdown_poll_interval)
272304

273305
logger.debug("Shutdown signal received")
274-
raise ValueError("kaki")
275306
raise asyncio.CancelledError("Shutdown event set, cancelling process loop.")
276307

277308
async def _process_synchronous_requests_loop(
278309
self,
279310
requests_queue: multiprocessing.Queue,
280311
results_queue: multiprocessing.Queue,
281312
process_id: int,
313+
shutdown_event: MultiprocessingEvent,
314+
shutdown_poll_interval_seconds: float,
282315
):
283316
while True:
284317
process_request = await self.get_request(
285318
requests_queue=requests_queue,
319+
shutdown_event=shutdown_event,
320+
process_id=process_id,
321+
shutdown_poll_interval_seconds=shutdown_poll_interval_seconds
286322
)
287323

288324
dequeued_time = time.time()
@@ -297,15 +333,14 @@ async def _process_synchronous_requests_loop(
297333
process_id=process_id,
298334
)
299335

300-
logger.debug("Done processing synchronous loop")
301-
302-
303336
async def _process_asynchronous_requests_loop(
304337
self,
305338
requests_queue: multiprocessing.Queue,
306339
results_queue: multiprocessing.Queue,
307340
max_concurrency: int,
308341
process_id: int,
342+
shutdown_event: MultiprocessingEvent,
343+
shutdown_poll_interval_seconds: float,
309344
):
310345
pending = asyncio.Semaphore(max_concurrency)
311346

@@ -316,6 +351,9 @@ async def _process_asynchronous_requests_loop(
316351
logger.info("Awaiting request...")
317352
process_request = await self.get_request(
318353
requests_queue=requests_queue,
354+
shutdown_event=shutdown_event,
355+
process_id=process_id,
356+
shutdown_poll_interval_seconds=shutdown_poll_interval_seconds,
319357
)
320358

321359
dequeued_time = time.time()
@@ -351,8 +389,6 @@ def _task_done(_: asyncio.Task):
351389
task.add_done_callback(_task_done)
352390
await asyncio.sleep(0) # enable start task immediately
353391

354-
logger.debug("Done processing asynchronous loop")
355-
356392

357393
class GenerativeRequestsWorkerDescription(WorkerDescription):
358394
type_: Literal["generative_requests_worker"] = "generative_requests_worker" # type: ignore[assignment]

0 commit comments

Comments
 (0)