Skip to content

Commit 0e8713c

Browse files
committed
Add wait_then_requeue behavior
Signed-off-by: Samuel Monson <[email protected]>
1 parent 1de1c64 commit 0e8713c

File tree

1 file changed

+49
-13
lines changed

1 file changed

+49
-13
lines changed

src/guidellm/scheduler/worker.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import time
1414
from multiprocessing.synchronize import Barrier as ProcessingBarrier
1515
from multiprocessing.synchronize import Event as ProcessingEvent
16-
from typing import Annotated, Generic, Literal
16+
from typing import Annotated, Generic, Literal, TypeAliasType
1717

1818
try:
1919
import uvloop
@@ -50,6 +50,16 @@
5050

5151
__all__ = ["WorkerProcess"]
5252

53+
# Generic alias for the value produced by processing one request: a
# (history, conversation, augmentation) tuple. `_process_next_request`
# returns it as `history, conversation, aug`; a non-empty `conversation`
# is what triggers a requeue of the remaining turns.
# NOTE(review): `typing.TypeAliasType` only exists on Python 3.12+ — confirm
# the project's minimum supported version, or import it from
# `typing_extensions` instead.
ProcessRequestT = TypeAliasType(
54+
"ProcessRequestT",
55+
tuple[
56+
HistoryT[RequestT, ResponseT],
57+
MultiTurnT[RequestT],
58+
ScheduledRequestAugmentation,
59+
],
60+
type_params=(RequestT, ResponseT),
61+
)
62+
5363

5464
class WorkerProcess(Generic[RequestT, ResponseT]):
5565
"""
@@ -271,12 +281,20 @@ async def _process_requests_loop(self):
271281
async_semaphore = asyncio.Semaphore(self.async_limit)
272282
pending_tasks: set[asyncio.Task] = set()
273283

274-
def _task_done(task):
284+
def _task_done(task: asyncio.Task[ProcessRequestT[RequestT, ResponseT]]):
    """
    Clean up after a finished request task and schedule any follow-up turns.

    Presumably registered as the done-callback of each request-processing
    task created by the surrounding loop (the registration is outside this
    block -- verify against the caller).

    :param task: The finished task; its result is the
        ``(history, conversation, aug)`` tuple of the processed request.
    :raises Exception: Re-raises the exception stored on ``task``, if any.
    """
    # Always free the bookkeeping slot and the concurrency permit, whether
    # the task succeeded, failed, or was cancelled.
    pending_tasks.discard(task)
    async_semaphore.release()

    # A cancelled task has neither a result nor a retrievable exception.
    if task.cancelled():
        return

    if exception := task.exception():
        raise exception

    history, conversation, aug = task.result()
    if conversation:
        # More turns remain: requeue them in the background so the optional
        # delay does not hold a concurrency permit.
        requeue_task = asyncio.create_task(
            self._wait_then_requeue(history, conversation, aug)
        )
        pending_tasks.add(requeue_task)
        # Drop the requeue task from the pending set once it finishes;
        # without this the set keeps every completed requeue task (and the
        # history/conversation it references) alive indefinitely.
        requeue_task.add_done_callback(pending_tasks.discard)
280298

281299
# Main loop; loop until canceled
282300
while True:
@@ -313,12 +331,14 @@ async def _cancel_requests_loop(self):
313331
request_info.scheduler_timings.resolve_end = time.time()
314332
self._send_update("cancelled", None, request, request_info)
315333

316-
async def _process_next_request(self):
317-
conversation: MultiTurnT[RequestT] | None = None
334+
async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]:
335+
conversation: MultiTurnT[RequestT] = []
336+
history: HistoryT[RequestT, ResponseT] = []
318337
request: RequestT | None = None
319338
request_info: ScheduledRequestInfo | None = None
320339
response: ResponseT | None = None
321340
aug: ScheduledRequestAugmentation | None = None
341+
premature_exit: bool = False
322342

323343
try:
324344
# Pull request from the queue
@@ -359,32 +379,48 @@ async def _process_next_request(self):
359379
request_info.scheduler_timings.resolve_end = time.time()
360380
self._send_update("completed", response, request, request_info)
361381

362-
# If multi-turn, queue up next turn(s)
363-
# TODO: Move to callback and support delay
364-
if conversation: # more turns to process
365-
history.append((request, response))
366-
self.turns_queue.append((history, conversation))
382+
# Record Turn
383+
history.append((request, response))
367384

368-
response = request = request_info = conversation = None
385+
response = request = request_info = None
369386
except asyncio.CancelledError:
387+
premature_exit = True
370388
# Handle cancellation
371389
if request is not None and request_info is not None:
372390
request_info.error = "Request was cancelled"
373391
request_info.scheduler_timings.resolve_end = time.time()
374392
self._send_update("cancelled", response, request, request_info)
375393
raise
376394
except Exception as exc: # noqa: BLE001
395+
premature_exit = True
377396
if request is not None and request_info is not None:
378397
request_info.error = str(exc)
379398
request_info.scheduler_timings.resolve_end = time.time()
380399
self._send_update("errored", response, request, request_info)
381400
finally:
382-
if conversation is not None:
401+
if premature_exit and conversation:
383402
for request, _, request_info in conversation:
384403
request_info.error = "Request was cancelled"
385404
request_info.scheduler_timings.resolve_end = time.time()
386405
self._send_update("cancelled", response, request, request_info)
387406

407+
return history, conversation, aug
408+
409+
async def _wait_then_requeue(
    self,
    history: HistoryT[RequestT, ResponseT],
    conversation: MultiTurnT[RequestT],
    aug: ScheduledRequestAugmentation,
):
    """
    Optionally wait, then put the remaining turns back on the turns queue.

    Sleeps for ``aug.post_requeue_delay`` seconds when that value is
    positive, then appends ``(history, conversation)`` to
    ``self.turns_queue``.

    :param history: Turns recorded so far for this conversation.
    :param conversation: Turns that are appended back to the queue.
    :param aug: Augmentation carrying ``post_requeue_delay``.
    :raises asyncio.CancelledError: Propagated if the task is cancelled
        during the delay; the turns are still requeued first.
    """
    try:
        delay = aug.post_requeue_delay
        if delay > 0:
            await asyncio.sleep(delay)
    finally:
        # Runs on normal completion, on error, and on cancellation alike, so
        # a cancelled wait dumps the turns straight onto the queue before
        # the CancelledError continues to propagate.
        self.turns_queue.append((history, conversation))
423+
388424
def _send_update(
389425
self,
390426
new_status: Literal[

0 commit comments

Comments
 (0)