|
13 | 13 | import time
|
14 | 14 | from multiprocessing.synchronize import Barrier as ProcessingBarrier
|
15 | 15 | from multiprocessing.synchronize import Event as ProcessingEvent
|
16 |
| -from typing import Annotated, Generic, Literal |
| 16 | +from typing import Annotated, Generic, Literal, TypeAliasType |
17 | 17 |
|
18 | 18 | try:
|
19 | 19 | import uvloop
|
|
50 | 50 |
|
51 | 51 | __all__ = ["WorkerProcess"]
|
52 | 52 |
|
# Generic alias for the value returned by ``_process_next_request``: a
# ``(history, conversation, augmentation)`` tuple, parameterized over the
# request and response types.
| 53 | +ProcessRequestT = TypeAliasType( |
| 54 | + "ProcessRequestT", |
| 55 | + tuple[ |
| 56 | + HistoryT[RequestT, ResponseT], |
| 57 | + MultiTurnT[RequestT], |
| 58 | + ScheduledRequestAugmentation, |
| 59 | + ], |
| 60 | + type_params=(RequestT, ResponseT), |
| 61 | +) |
| 62 | + |
53 | 63 |
|
54 | 64 | class WorkerProcess(Generic[RequestT, ResponseT]):
|
55 | 65 | """
|
@@ -271,12 +281,20 @@ async def _process_requests_loop(self):
|
271 | 281 | async_semaphore = asyncio.Semaphore(self.async_limit)
|
272 | 282 | pending_tasks: set[asyncio.Task] = set()
|
273 | 283 |
|
274 |
def _task_done(task: asyncio.Task[ProcessRequestT[RequestT, ResponseT]]):
    """
    Done-callback for a request-processing task.

    Removes the task from ``pending_tasks`` and releases its
    ``async_semaphore`` slot. If the task was not cancelled, its exception
    (if any) is re-raised; otherwise its result is unpacked and, when
    conversation turns remain, a ``_wait_then_requeue`` task is scheduled
    for them.

    :param task: Finished task whose result is the
        ``(history, conversation, aug)`` tuple.
    :raises Exception: The exception the task finished with, if any.
    """
    pending_tasks.discard(task)
    async_semaphore.release()

    if task.cancelled():
        return

    # NOTE(review): this runs as a done-callback, so the raise below is
    # presumably reported through the event loop's exception handler rather
    # than propagated into the main loop -- confirm that is intended.
    if exception := task.exception():
        raise exception

    history, conversation, aug = task.result()
    if not conversation:
        # No turns left; nothing to requeue.
        return

    requeue_task = asyncio.create_task(
        self._wait_then_requeue(history, conversation, aug)
    )
    # Track the requeue task like any other pending task, and drop it from
    # the set once it finishes so completed tasks do not accumulate there.
    pending_tasks.add(requeue_task)
    requeue_task.add_done_callback(pending_tasks.discard)
280 | 298 |
|
281 | 299 | # Main loop; loop until canceled
|
282 | 300 | while True:
|
@@ -313,12 +331,14 @@ async def _cancel_requests_loop(self):
|
313 | 331 | request_info.scheduler_timings.resolve_end = time.time()
|
314 | 332 | self._send_update("cancelled", None, request, request_info)
|
315 | 333 |
|
316 |
| - async def _process_next_request(self): |
317 |
| - conversation: MultiTurnT[RequestT] | None = None |
| 334 | + async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]: |
| 335 | + conversation: MultiTurnT[RequestT] = [] |
| 336 | + history: HistoryT[RequestT, ResponseT] = [] |
318 | 337 | request: RequestT | None = None
|
319 | 338 | request_info: ScheduledRequestInfo | None = None
|
320 | 339 | response: ResponseT | None = None
|
321 | 340 | aug: ScheduledRequestAugmentation | None = None
|
| 341 | + premature_exit: bool = False |
322 | 342 |
|
323 | 343 | try:
|
324 | 344 | # Pull request from the queue
|
@@ -359,32 +379,48 @@ async def _process_next_request(self):
|
359 | 379 | request_info.scheduler_timings.resolve_end = time.time()
|
360 | 380 | self._send_update("completed", response, request, request_info)
|
361 | 381 |
|
362 |
| - # If multi-turn, queue up next turn(s) |
363 |
| - # TODO: Move to callback and support delay |
364 |
| - if conversation: # more turns to process |
365 |
| - history.append((request, response)) |
366 |
| - self.turns_queue.append((history, conversation)) |
| 382 | + # Record Turn |
| 383 | + history.append((request, response)) |
367 | 384 |
|
368 |
| - response = request = request_info = conversation = None |
| 385 | + response = request = request_info = None |
369 | 386 | except asyncio.CancelledError:
|
| 387 | + premature_exit = True |
370 | 388 | # Handle cancellation
|
371 | 389 | if request is not None and request_info is not None:
|
372 | 390 | request_info.error = "Request was cancelled"
|
373 | 391 | request_info.scheduler_timings.resolve_end = time.time()
|
374 | 392 | self._send_update("cancelled", response, request, request_info)
|
375 | 393 | raise
|
376 | 394 | except Exception as exc: # noqa: BLE001
|
| 395 | + premature_exit = True |
377 | 396 | if request is not None and request_info is not None:
|
378 | 397 | request_info.error = str(exc)
|
379 | 398 | request_info.scheduler_timings.resolve_end = time.time()
|
380 | 399 | self._send_update("errored", response, request, request_info)
|
381 | 400 | finally:
|
382 |
| - if conversation is not None: |
| 401 | + if premature_exit and conversation: |
383 | 402 | for request, _, request_info in conversation:
|
384 | 403 | request_info.error = "Request was cancelled"
|
385 | 404 | request_info.scheduler_timings.resolve_end = time.time()
|
386 | 405 | self._send_update("cancelled", response, request, request_info)
|
387 | 406 |
|
| 407 | + return history, conversation, aug |
| 408 | + |
async def _wait_then_requeue(
    self,
    history: HistoryT[RequestT, ResponseT],
    conversation: MultiTurnT[RequestT],
    aug: ScheduledRequestAugmentation,
):
    """
    Append a conversation to ``turns_queue`` after an optional delay.

    Sleeps for ``aug.post_requeue_delay`` seconds when that value is
    positive. The ``(history, conversation)`` pair is appended in a
    ``finally`` clause, so it is queued even if the wait is cancelled or
    fails; the interrupting exception still propagates to the caller.

    :param history: History value queued together with the conversation.
    :param conversation: Conversation value to place back on the queue.
    :param aug: Augmentation supplying ``post_requeue_delay``.
    :raises asyncio.CancelledError: If cancelled while waiting (the pair is
        still queued first).
    """
    try:
        delay = aug.post_requeue_delay
        if delay > 0:
            await asyncio.sleep(delay)
    finally:
        # Runs on cancellation as well: the turns go straight to the queue
        # and the CancelledError continues to propagate.
        self.turns_queue.append((history, conversation))
| 423 | + |
388 | 424 | def _send_update(
|
389 | 425 | self,
|
390 | 426 | new_status: Literal[
|
|
0 commit comments