|
87 | 87 |
|
88 | 88 | import argparse
|
89 | 89 | import asyncio
|
| 90 | +import functools |
90 | 91 | import heapq
|
91 | 92 | import os
|
92 | 93 | import sys
|
| 94 | +import uuid |
93 | 95 | from contextlib import asynccontextmanager
|
94 | 96 | from typing import List
|
95 | 97 |
|
@@ -137,7 +139,6 @@ def __init__(self, prefiller_instances, decoder_instances):
|
137 | 139 | ]
|
138 | 140 | self.req_to_prefiller = {}
|
139 | 141 | self.req_id_lock = asyncio.Lock()
|
140 |
| - self.req_id_counter = 0 |
141 | 142 | # Removed selection locks - no longer needed for synchronous methods
|
142 | 143 |
|
143 | 144 | # Initialize priority queues for efficient server selection
|
@@ -193,8 +194,7 @@ def aquire_aborted_prefiller_requests(
|
193 | 194 |
|
async def next_req_id(self):
    """Return a new request ID: a random UUID4 rendered as a string.

    The UUID is drawn while ``self.req_id_lock`` is held, and converted to
    its canonical string form before being returned.
    """
    async with self.req_id_lock:
        fresh_id = uuid.uuid4()
    return str(fresh_id)
198 | 198 |
|
199 | 199 | def select_prefiller(self, token_count): # Changed to synchronous
|
200 | 200 | # No lock needed - entire function is atomic
|
@@ -313,6 +313,32 @@ async def lifespan(app: FastAPI):
|
313 | 313 | await d.client.aclose()
|
314 | 314 |
|
315 | 315 |
|
async def listen_for_disconnect(request: Request) -> None:
    """Wait until the client side of ``request`` disconnects.

    Repeatedly awaits ``request.receive()`` and discards every message
    until one whose ``"type"`` is ``"http.disconnect"`` arrives, then
    returns.
    """
    disconnected = False
    while not disconnected:
        event = await request.receive()
        disconnected = event["type"] == "http.disconnect"
| 322 | + |
| 323 | + |
def with_cancellation(handler_func):
    """Decorate an async request handler so it stops when the client disconnects.

    The handler runs as its own task next to a second task that waits in
    ``listen_for_disconnect``. Whichever task finishes first decides the
    outcome; the other one is cancelled.

    Args:
        handler_func: Coroutine function to wrap. The wrapper looks the
            request up as the keyword argument ``request``, so the handler
            must be invoked with ``request=...``.

    Returns:
        An async wrapper (metadata copied with ``functools.wraps``) that
        returns the handler's result, or ``None`` when the disconnect
        listener finished first.

    Raises:
        KeyError: At call time, if no ``request`` keyword argument is given.
        Exception: Whatever the handler raised, re-raised through
            ``handler_task.result()``.
    """

    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # Looked up before any task is created, so a missing keyword fails
        # fast without leaving an un-awaited coroutine behind.
        request = kwargs["request"]
        # NOTE(review): listen_for_disconnect pulls messages from
        # request.receive(); if handler_func also reads the request body
        # from that same request, the two tasks may compete for messages.
        # Confirm against the handlers this decorator is applied to.
        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        try:
            done, _ = await asyncio.wait(
                [handler_task, cancellation_task],
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            # asyncio.wait() never cancels the tasks it waits on, so clean
            # up here. This also covers the case where this wrapper is
            # itself cancelled mid-wait, which would otherwise leave both
            # tasks running with nobody awaiting them.
            for task in (handler_task, cancellation_task):
                if not task.done():
                    task.cancel()
        if handler_task in done:
            return handler_task.result()
        # The client disconnected first; there is no response to send.
        return None

    return wrapper
| 340 | + |
| 341 | + |
316 | 342 | app = FastAPI(lifespan=lifespan)
|
317 | 343 |
|
318 | 344 |
|
@@ -493,11 +519,13 @@ async def generate_stream():
|
493 | 519 |
|
494 | 520 |
|
@app.post("/v1/completions")
@with_cancellation
async def handle_completions(request: Request):
    """Serve ``POST /v1/completions``.

    Forwards the incoming request to ``_handle_completions`` with the
    ``"/completions"`` API path and returns whatever it produces.
    """
    response = await _handle_completions("/completions", request)
    return response
|
498 | 525 |
|
499 | 526 |
|
@app.post("/v1/chat/completions")
@with_cancellation
async def handle_chat_completions(request: Request):
    """Serve ``POST /v1/chat/completions``.

    Forwards the incoming request to ``_handle_completions`` with the
    ``"/chat/completions"`` API path and returns whatever it produces.
    """
    response = await _handle_completions("/chat/completions", request)
    return response
|
503 | 531 |
|
|
0 commit comments