Skip to content

Commit cc08fc7

Browse files
[Frontend] Reapply "Factor out code for running uvicorn" (#7095)
1 parent 7b86e7c commit cc08fc7

File tree

3 files changed

+125
-82
lines changed

3 files changed

+125
-82
lines changed

vllm/entrypoints/api_server.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,23 @@
55
We are also not going to accept PRs modifying this file, please
66
change `vllm/entrypoints/openai/api_server.py` instead.
77
"""
8-
8+
import asyncio
99
import json
1010
import ssl
11-
from typing import AsyncGenerator
11+
from argparse import Namespace
12+
from typing import Any, AsyncGenerator, Optional
1213

13-
import uvicorn
1414
from fastapi import FastAPI, Request
1515
from fastapi.responses import JSONResponse, Response, StreamingResponse
1616

1717
from vllm.engine.arg_utils import AsyncEngineArgs
1818
from vllm.engine.async_llm_engine import AsyncLLMEngine
19+
from vllm.entrypoints.launcher import serve_http
1920
from vllm.logger import init_logger
2021
from vllm.sampling_params import SamplingParams
2122
from vllm.usage.usage_lib import UsageContext
2223
from vllm.utils import FlexibleArgumentParser, random_uuid
24+
from vllm.version import __version__ as VLLM_VERSION
2325

2426
logger = init_logger("vllm.entrypoints.api_server")
2527

@@ -81,6 +83,53 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
8183
return JSONResponse(ret)
8284

8385

86+
def build_app(args: Namespace) -> FastAPI:
    """Configure and return the module-level FastAPI app.

    Mutates the global ``app`` by applying the CLI-provided ``root_path``
    so the server works correctly behind a path-rewriting proxy.
    """
    global app

    app.root_path = args.root_path
    return app
92+
93+
async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    """Build the FastAPI app and initialise the module-level engine.

    When *llm_engine* is supplied (e.g. by embedding callers or tests) it
    is used as-is; otherwise a fresh ``AsyncLLMEngine`` is constructed
    from the parsed CLI arguments.
    """
    app = build_app(args)

    global engine

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is not None:
        engine = llm_engine
    else:
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER)

    return app
107+
108+
109+
async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    """Initialise the app/engine and serve HTTP until shutdown completes.

    Extra keyword arguments are forwarded verbatim to uvicorn via
    ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    app = await init_app(args, llm_engine)

    # serve_http returns an awaitable that completes the graceful shutdown.
    shutdown_task = await serve_http(
        app,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    await shutdown_task
131+
132+
84133
if __name__ == "__main__":
85134
parser = FlexibleArgumentParser()
86135
parser.add_argument("--host", type=str, default=None)
@@ -105,25 +154,5 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
105154
parser.add_argument("--log-level", type=str, default="debug")
106155
parser = AsyncEngineArgs.add_cli_args(parser)
107156
args = parser.parse_args()
108-
engine_args = AsyncEngineArgs.from_cli_args(args)
109-
engine = AsyncLLMEngine.from_engine_args(
110-
engine_args, usage_context=UsageContext.API_SERVER)
111-
112-
app.root_path = args.root_path
113157

114-
logger.info("Available routes are:")
115-
for route in app.routes:
116-
if not hasattr(route, 'methods'):
117-
continue
118-
methods = ', '.join(route.methods)
119-
logger.info("Route: %s, Methods: %s", route.path, methods)
120-
121-
uvicorn.run(app,
122-
host=args.host,
123-
port=args.port,
124-
log_level=args.log_level,
125-
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
126-
ssl_keyfile=args.ssl_keyfile,
127-
ssl_certfile=args.ssl_certfile,
128-
ssl_ca_certs=args.ssl_ca_certs,
129-
ssl_cert_reqs=args.ssl_cert_reqs)
158+
asyncio.run(run_server(args))

vllm/entrypoints/launcher.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import asyncio
2+
import signal
3+
from typing import Any
4+
5+
import uvicorn
6+
from fastapi import FastAPI
7+
8+
from vllm.logger import init_logger
9+
10+
logger = init_logger(__name__)
11+
12+
13+
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
    """Run *app* under uvicorn with graceful SIGINT/SIGTERM handling.

    Returns an awaitable shutdown task: a no-op coroutine when the server
    exits on its own, or ``server.shutdown()`` when it was cancelled by a
    signal, so callers can await the shutdown after tearing down their
    own resources.
    """
    logger.info("Available routes are:")
    for route in app.routes:
        methods = getattr(route, "methods", None)
        path = getattr(route, "path", None)
        # Mounts and other route types lack methods/path; skip them.
        if methods is None or path is None:
            continue
        logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

    config = uvicorn.Config(app, **uvicorn_kwargs)
    server = uvicorn.Server(config)

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.serve())

    def signal_handler() -> None:
        # Cancel the task ourselves; prevents the uvicorn signal handler
        # from exiting early before our caller can clean up.
        server_task.cancel()

    async def dummy_shutdown() -> None:
        pass

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
        return dummy_shutdown()
    except asyncio.CancelledError:
        logger.info("Gracefully stopping http server")
        return server.shutdown()

vllm/entrypoints/openai/api_server.py

Lines changed: 26 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22
import importlib
33
import inspect
44
import re
5-
import signal
5+
from argparse import Namespace
66
from contextlib import asynccontextmanager
77
from http import HTTPStatus
88
from multiprocessing import Process
99
from typing import AsyncIterator, Set
1010

11-
import fastapi
12-
import uvicorn
13-
from fastapi import APIRouter, Request
11+
from fastapi import APIRouter, FastAPI, Request
1412
from fastapi.exceptions import RequestValidationError
1513
from fastapi.middleware.cors import CORSMiddleware
1614
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -22,6 +20,7 @@
2220
from vllm.engine.arg_utils import AsyncEngineArgs
2321
from vllm.engine.async_llm_engine import AsyncLLMEngine
2422
from vllm.engine.protocol import AsyncEngineClient
23+
from vllm.entrypoints.launcher import serve_http
2524
from vllm.entrypoints.logger import RequestLogger
2625
from vllm.entrypoints.openai.cli_args import make_arg_parser
2726
# yapf conflicts with isort for this block
@@ -71,7 +70,7 @@ def model_is_embedding(model_name: str) -> bool:
7170

7271

7372
@asynccontextmanager
74-
async def lifespan(app: fastapi.FastAPI):
73+
async def lifespan(app: FastAPI):
7574

7675
async def _force_log():
7776
while True:
@@ -135,7 +134,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
135134
router = APIRouter()
136135

137136

138-
def mount_metrics(app: fastapi.FastAPI):
137+
def mount_metrics(app: FastAPI):
139138
# Add prometheus asgi middleware to route /metrics requests
140139
metrics_route = Mount("/metrics", make_asgi_app())
141140
# Workaround for 307 Redirect for /metrics
@@ -225,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
225224
return JSONResponse(content=generator.model_dump())
226225

227226

228-
def build_app(args):
229-
app = fastapi.FastAPI(lifespan=lifespan)
227+
def build_app(args: Namespace) -> FastAPI:
228+
app = FastAPI(lifespan=lifespan)
230229
app.include_router(router)
231230
app.root_path = args.root_path
232231

@@ -274,11 +273,10 @@ async def authentication(request: Request, call_next):
274273
return app
275274

276275

277-
async def build_server(
276+
async def init_app(
278277
async_engine_client: AsyncEngineClient,
279-
args,
280-
**uvicorn_kwargs,
281-
) -> uvicorn.Server:
278+
args: Namespace,
279+
) -> FastAPI:
282280
app = build_app(args)
283281

284282
if args.served_model_name is not None:
@@ -334,62 +332,31 @@ async def build_server(
334332
)
335333
app.root_path = args.root_path
336334

337-
logger.info("Available routes are:")
338-
for route in app.routes:
339-
if not hasattr(route, 'methods'):
340-
continue
341-
methods = ', '.join(route.methods)
342-
logger.info("Route: %s, Methods: %s", route.path, methods)
343-
344-
config = uvicorn.Config(
345-
app,
346-
host=args.host,
347-
port=args.port,
348-
log_level=args.uvicorn_log_level,
349-
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
350-
ssl_keyfile=args.ssl_keyfile,
351-
ssl_certfile=args.ssl_certfile,
352-
ssl_ca_certs=args.ssl_ca_certs,
353-
ssl_cert_reqs=args.ssl_cert_reqs,
354-
**uvicorn_kwargs,
355-
)
356-
357-
return uvicorn.Server(config)
335+
return app
358336

359337

360338
async def run_server(args, **uvicorn_kwargs) -> None:
    """Build the async engine client and serve the OpenAI-compatible API.

    Extra keyword arguments are forwarded verbatim to uvicorn via
    ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as async_engine_client:
        app = await init_app(async_engine_client, args)

        shutdown_task = await serve_http(
            app,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
393360

394361

395362
if __name__ == "__main__":
@@ -399,4 +366,5 @@ def signal_handler() -> None:
399366
description="vLLM OpenAI-Compatible RESTful API server.")
400367
parser = make_arg_parser(parser)
401368
args = parser.parse_args()
369+
402370
asyncio.run(run_server(args))

0 commit comments

Comments
 (0)