
Commit dbfe254

EthanqX and simon-mo authored
[Feature] vLLM CLI (#5090)
Co-authored-by: simon-mo <[email protected]>
1 parent 73030b7 commit dbfe254

File tree

7 files changed: +223, -36 lines changed


benchmarks/benchmark_serving.py

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@
 
 On the server side, run one of the following commands:
     vLLM OpenAI API server
-    python -m vllm.entrypoints.openai.api_server \
-        --model <your_model> --swap-space 16 \
+    vllm serve <your_model> \
+        --swap-space 16 \
         --disable-log-requests
 
     (TGI backend)

docs/source/serving/openai_compatible_server.md

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
 
 ```{argparse}
 :module: vllm.entrypoints.openai.cli_args
-:func: make_arg_parser
+:func: create_parser_for_docs
 :prog: -m vllm.entrypoints.openai.api_server
 ```
 

setup.py

Lines changed: 5 additions & 0 deletions
@@ -488,4 +488,9 @@ def _read_requirements(filename: str) -> List[str]:
     },
     cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
     package_data=package_data,
+    entry_points={
+        "console_scripts": [
+            "vllm=vllm.scripts:main",
+        ],
+    },
 )
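
The new console_scripts entry point maps the vllm command to vllm.scripts:main. That module is the seventh changed file in this commit but its hunks are not shown on this page, so the following is only a rough sketch of what such a dispatcher could look like, pieced together from the entry point above, the "vllm serve <your_model>" usage in the benchmark docstring, and the NOTE in api_server.py; treat every name and detail here as an assumption rather than the committed code.

    # Assumed sketch only; the real vllm/scripts.py from this commit is not
    # reproduced on this page.
    from vllm.entrypoints.openai.api_server import run_server
    from vllm.entrypoints.openai.cli_args import make_arg_parser
    from vllm.utils import FlexibleArgumentParser


    def serve(args):
        # "vllm serve <model_tag>" forwards the positional tag to --model,
        # mirroring the updated benchmark docstring (assumed behavior).
        args.model = args.model_tag
        run_server(args)


    def main():
        parser = FlexibleArgumentParser(description="vLLM CLI")
        subparsers = parser.add_subparsers(required=True, dest="subcommand")

        serve_parser = subparsers.add_parser(
            "serve", help="Start the OpenAI-compatible API server")
        serve_parser.add_argument("model_tag",
                                  type=str,
                                  help="Model name or path to serve")
        serve_parser = make_arg_parser(serve_parser)
        serve_parser.set_defaults(dispatch_function=serve)

        args = parser.parse_args()
        args.dispatch_function(args)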

tests/utils.py

Lines changed: 4 additions & 2 deletions
@@ -14,7 +14,7 @@
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.utils import get_open_port, is_hip
+from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
 
 if is_hip():
     from amdsmi import (amdsmi_get_gpu_vram_usage,
@@ -57,7 +57,9 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
 
             cli_args = cli_args + ["--port", str(get_open_port())]
 
-        parser = make_arg_parser()
+        parser = FlexibleArgumentParser(
+            description="vLLM's remote OpenAI server.")
+        parser = make_arg_parser(parser)
         args = parser.parse_args(cli_args)
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
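
Because make_arg_parser no longer builds its own parser, callers such as this test helper create a FlexibleArgumentParser first and then hand it over. A minimal sketch of the same pattern for parsing server flags programmatically; the model name and port below are placeholders, not values from this commit.

    # Minimal sketch: build the server's parser and parse an explicit argv list.
    from vllm.entrypoints.openai.cli_args import make_arg_parser
    from vllm.utils import FlexibleArgumentParser

    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
    parser = make_arg_parser(parser)  # adds server flags plus engine flags

    args = parser.parse_args(["--model", "facebook/opt-125m", "--port", "8000"])
    print(args.model, args.port)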

vllm/entrypoints/openai/api_server.py

Lines changed: 50 additions & 28 deletions
@@ -8,7 +8,7 @@
 
 import fastapi
 import uvicorn
-from fastapi import Request
+from fastapi import APIRouter, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -35,10 +35,14 @@
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
+logger = init_logger(__name__)
+engine: AsyncLLMEngine
+engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
@@ -64,35 +68,23 @@ async def _force_log():
     yield
 
 
-app = fastapi.FastAPI(lifespan=lifespan)
-
-
-def parse_args():
-    parser = make_arg_parser()
-    return parser.parse_args()
-
+router = APIRouter()
 
 # Add prometheus asgi middleware to route /metrics requests
 route = Mount("/metrics", make_asgi_app())
 # Workaround for 307 Redirect for /metrics
 route.path_regex = re.compile('^/metrics(?P<path>.*)$')
-app.routes.append(route)
-
-
-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
+router.routes.append(route)
 
 
-@app.get("/health")
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
     await openai_serving_chat.engine.check_health()
     return Response(status_code=200)
 
 
-@app.post("/tokenize")
+@router.post("/tokenize")
 async def tokenize(request: TokenizeRequest):
     generator = await openai_serving_completion.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
@@ -103,7 +95,7 @@ async def tokenize(request: TokenizeRequest):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/detokenize")
+@router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
     generator = await openai_serving_completion.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
@@ -114,19 +106,19 @@ async def detokenize(request: DetokenizeRequest):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.get("/v1/models")
+@router.get("/v1/models")
 async def show_available_models():
     models = await openai_serving_completion.show_available_models()
     return JSONResponse(content=models.model_dump())
 
 
-@app.get("/version")
+@router.get("/version")
 async def show_version():
     ver = {"version": VLLM_VERSION}
     return JSONResponse(content=ver)
 
 
-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -142,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/completions")
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -156,7 +148,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/embeddings")
+@router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
@@ -167,8 +159,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-if __name__ == "__main__":
-    args = parse_args()
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path
 
     app.add_middleware(
         CORSMiddleware,
@@ -178,6 +172,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
         allow_headers=args.allowed_headers,
     )
 
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
     if token := envs.VLLM_API_KEY or args.api_key:
 
         @app.middleware("http")
@@ -203,6 +203,12 @@ async def authentication(request: Request, call_next):
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")
 
+    return app
+
+
+def run_server(args, llm_engine=None):
+    app = build_app(args)
+
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
@@ -211,10 +217,12 @@ async def authentication(request: Request, call_next):
     else:
         served_model_names = [args.model]
 
-    engine_args = AsyncEngineArgs.from_cli_args(args)
+    global engine, engine_args
 
-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
 
     event_loop: Optional[asyncio.AbstractEventLoop]
     try:
@@ -230,6 +238,10 @@ async def authentication(request: Request, call_next):
         # When using single vLLM without engine_use_ray
         model_config = asyncio.run(engine.get_model_config())
 
+    global openai_serving_chat
+    global openai_serving_completion
+    global openai_serving_embedding
+
     openai_serving_chat = OpenAIServingChat(engine, model_config,
                                             served_model_names,
                                             args.response_role,
@@ -258,3 +270,13 @@ async def authentication(request: Request, call_next):
                 ssl_certfile=args.ssl_certfile,
                 ssl_ca_certs=args.ssl_ca_certs,
                 ssl_cert_reqs=args.ssl_cert_reqs)
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+    run_server(args)
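
run_server now accepts an optional llm_engine, so callers can inject an engine they have already constructed instead of letting the server build one from the CLI arguments. A hedged usage sketch, assuming the imports below and a placeholder model name; this is not code from the commit.

    # Sketch: start the OpenAI-compatible server with a caller-built engine.
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.entrypoints.openai.api_server import run_server
    from vllm.entrypoints.openai.cli_args import make_arg_parser
    from vllm.utils import FlexibleArgumentParser

    parser = make_arg_parser(FlexibleArgumentParser())
    args = parser.parse_args(["--model", "facebook/opt-125m"])  # placeholder

    # Build the engine up front; run_server skips its own construction when
    # llm_engine is not None.
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs.from_cli_args(args))

    run_server(args, llm_engine=engine)  # serves HTTP until shutdown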

vllm/entrypoints/openai/cli_args.py

Lines changed: 7 additions & 3 deletions
@@ -34,9 +34,7 @@ def __call__(self, parser, namespace, values, option_string=None):
         setattr(namespace, self.dest, adapter_list)
 
 
-def make_arg_parser():
-    parser = FlexibleArgumentParser(
-        description="vLLM OpenAI-Compatible RESTful API server.")
+def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser.add_argument("--host",
                         type=nullable_str,
                         default=None,
@@ -133,3 +131,9 @@ def make_arg_parser():
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser
+
+
+def create_parser_for_docs() -> FlexibleArgumentParser:
+    parser_for_docs = FlexibleArgumentParser(
+        prog="-m vllm.entrypoints.openai.api_server")
+    return make_arg_parser(parser_for_docs)
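
create_parser_for_docs exists so that the sphinx-argparse block in openai_compatible_server.md (updated above) has a zero-argument factory to call. As a quick local sanity check, not part of the commit, the options the docs will render can be previewed like this:

    # Print the help text that the docs' argparse directive will document.
    from vllm.entrypoints.openai.cli_args import create_parser_for_docs

    print(create_parser_for_docs().format_help())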
