Commit c9927c1

Use queue for finished requests (#957)

1 parent fbd80ad

File tree: 2 files changed (+13, −9 lines)


tests/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -156,8 +156,8 @@ def generate_greedy(
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params)
-        return [(output_ids[0], output_str[0]) for output_ids, output_str in
-                outputs]
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,

vllm/engine/async_llm_engine.py

Lines changed: 11 additions & 7 deletions
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -152,7 +152,7 @@ def __init__(self,
 
         # Request id -> stream.
         self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: Set[str] = set()
+        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
@@ -194,12 +194,14 @@ async def engine_step(self):
                 if self.log_requests:
                     logger.info(f"Finished request {request_id}.")
                 self.request_streams[request_id].finish()
-                self.finished_requests.add(request_id)
+                self.finished_requests.put_nowait(request_id)
 
-        await self._engine_abort(self.finished_requests)
-        for request_id in self.finished_requests:
+        finished_request = set()
+        while not self.finished_requests.empty():
+            finished_request.add(self.finished_requests.get_nowait())
+        await self._engine_abort(finished_request)
+        for request_id in finished_request:
             del self.request_streams[request_id]
-        self.finished_requests.clear()
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
@@ -226,6 +228,8 @@ async def add_request(
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")
 
+        if request_id in self.request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
         stream = AsyncStream(request_id)
         self.request_streams[request_id] = stream
 
@@ -316,7 +320,7 @@ def _abort(self, request_id: str) -> None:
             logger.info(f"Aborted request {request_id}.")
 
         self.request_streams[request_id].finish()
-        self.finished_requests.add(request_id)
+        self.finished_requests.put_nowait(request_id)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
