Commit 0804384

Authored by Yard1 and zhuohan123
Start background task in AsyncLLMEngine.generate (#988)
Co-authored-by: Zhuohan Li <[email protected]>
1 parent: 4b5bcf8 · commit: 0804384

File tree

4 files changed: +16 -24 lines changed


tests/async_engine/api_server_async_engine.py

Lines changed: 1 addition & 2 deletions
@@ -40,8 +40,7 @@ def stats() -> Response:
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
-                                                      start_engine_loop=False)
+    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
     vllm.entrypoints.api_server.engine = engine
     uvicorn.run(
         app,

vllm/engine/async_llm_engine.py

Lines changed: 13 additions & 9 deletions
@@ -230,6 +230,8 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        start_engine_loop: If True, the background task to run the engine
+            will be automatically started in the generate call.
         *args, *kwargs: Arguments for LLMEngine.
     """
@@ -240,7 +242,7 @@ def __init__(self,
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 start_engine_loop: bool = False,
+                 start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
@@ -249,8 +251,7 @@
 
         self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
-        if start_engine_loop:
-            self.start_background_loop()
+        self.start_engine_loop = start_engine_loop
 
     @property
     def is_running(self) -> bool:
@@ -330,11 +331,14 @@ async def add_request(
             f"prompt token ids: {prompt_token_ids}.")
 
         if not self.is_running:
-            raise AsyncEngineDeadError(
-                "Background loop is not running. If it was running, "
-                "inspect the output to find the stacktrace of the "
-                "error that caused the background loop to stop "
-                "(AsyncEngineDeadError).")
+            if self.start_engine_loop:
+                self.start_background_loop()
+            else:
+                raise AsyncEngineDeadError(
+                    "Background loop is not running. If it was running, "
+                    "inspect the output to find the stacktrace of the "
+                    "error that caused the background loop to stop "
+                    "(AsyncEngineDeadError).")
 
         stream = self.request_tracker.add_request(
             request_id,
@@ -426,7 +430,7 @@ async def get_model_config(self) -> ModelConfig:
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = False) -> "AsyncLLMEngine":
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
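
The heart of the change: instead of requiring callers to start the background loop themselves, __init__ now just records start_engine_loop, and add_request starts the loop on first use. Below is a minimal, self-contained sketch of that lazy-start pattern; LazyLoopEngine and its method bodies are illustrative stand-ins, not vLLM's actual implementation.

import asyncio

class LazyLoopEngine:
    """Illustrative stand-in for AsyncLLMEngine's deferred background loop."""

    def __init__(self, start_engine_loop: bool = True) -> None:
        self.start_engine_loop = start_engine_loop
        self.background_loop = None  # becomes an asyncio.Task once started

    @property
    def is_running(self) -> bool:
        return (self.background_loop is not None
                and not self.background_loop.done())

    def start_background_loop(self) -> None:
        # Must be called while an event loop is running.
        self.background_loop = asyncio.get_running_loop().create_task(
            self._engine_loop())

    async def _engine_loop(self) -> None:
        while True:
            await asyncio.sleep(0.1)  # stand-in for repeated engine steps

    async def add_request(self, request_id: str) -> None:
        if not self.is_running:
            if self.start_engine_loop:
                # New behavior: lazily start the loop on the first request.
                self.start_background_loop()
            else:
                raise RuntimeError("Background loop is not running.")
        print(f"queued request {request_id}")

async def main() -> None:
    engine = LazyLoopEngine()          # no background task yet
    await engine.add_request("req-0")  # first request starts the loop
    assert engine.is_running

asyncio.run(main())

A likely benefit of deferring the start: the task is created on whichever event loop actually serves requests (e.g. uvicorn's), rather than at construction time, before that loop exists.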

vllm/entrypoints/api_server.py

Lines changed: 1 addition & 5 deletions
@@ -32,9 +32,6 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
@@ -80,8 +77,7 @@ async def abort_request() -> None:
     args = parser.parse_args()
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
 
     uvicorn.run(app,
                 host=args.host,
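
With the default flipped to start_engine_loop=True, the server entrypoints shrink to a single from_engine_args call, and the first generate() starts the engine loop. A minimal direct-usage sketch under that same assumption; the model name and prompt are placeholders, and running it requires a working vLLM install with a supported GPU:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import random_uuid

async def main() -> None:
    # from_engine_args() now defaults to start_engine_loop=True, so the
    # background task is started inside the first generate() call.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # placeholder model
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    params = SamplingParams(temperature=0.8, max_tokens=16)
    final_output = None
    async for output in engine.generate("Hello, my name is", params,
                                        random_uuid()):
        final_output = output  # keep the latest (cumulative) output

    print(final_output.outputs[0].text)

asyncio.run(main())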

vllm/entrypoints/openai/api_server.py

Lines changed: 1 addition & 8 deletions
@@ -192,9 +192,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     """
     logger.info(f"Received chat completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -367,9+364,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     """
     logger.info(f"Received completion request: {request}")
 
-    if not engine.is_running:
-        engine.start_background_loop()
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -627,8 +621,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     served_model = args.model
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                             start_engine_loop=False)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
     max_model_len = engine_model_config.get_max_model_len()
634627
