@@ -8,7 +8,7 @@

import fastapi
import uvicorn
-from fastapi import Request
+from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -35,10 +35,14 @@
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5  # seconds

+logger = init_logger(__name__)
+engine: AsyncLLMEngine
+engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
@@ -64,35 +68,23 @@ async def _force_log():
    yield


-app = fastapi.FastAPI(lifespan=lifespan)
-
-
-def parse_args():
-    parser = make_arg_parser()
-    return parser.parse_args()
-
+router = APIRouter()

# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
-app.routes.append(route)
-
-
-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
+router.routes.append(route)


-@app.get("/health")
+@router.get("/health")
async def health() -> Response:
    """Health check."""
    await openai_serving_chat.engine.check_health()
    return Response(status_code=200)


-@app.post("/tokenize")
+@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    generator = await openai_serving_completion.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
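With the handlers registered on an APIRouter instead of a module-level FastAPI app, the same routes can be attached to any application object. A minimal sketch of that reuse, assuming this file is importable as vllm.entrypoints.openai.api_server (module path not shown in the diff itself):

import fastapi

from vllm.entrypoints.openai.api_server import router  # assumed module path

app = fastapi.FastAPI()
app.include_router(router)  # attaches the OpenAI-compatible endpoints registered above

# List the paths now exposed by this app; actually serving requests also needs
# the openai_serving_* globals set up, which run_server() further down handles.
for r in app.routes:
    print(getattr(r, "path", r))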
@@ -103,7 +95,7 @@ async def tokenize(request: TokenizeRequest):
        return JSONResponse(content=generator.model_dump())


-@app.post("/detokenize")
+@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    generator = await openai_serving_completion.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
@@ -114,19 +106,19 @@ async def detokenize(request: DetokenizeRequest):
        return JSONResponse(content=generator.model_dump())


-@app.get("/v1/models")
+@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_completion.show_available_models()
    return JSONResponse(content=models.model_dump())


-@app.get("/version")
+@router.get("/version")
async def show_version():
    ver = {"version": VLLM_VERSION}
    return JSONResponse(content=ver)


-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    generator = await openai_serving_chat.create_chat_completion(
@@ -142,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
        return JSONResponse(content=generator.model_dump())


-@app.post("/v1/completions")
+@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
@@ -156,7 +148,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
        return JSONResponse(content=generator.model_dump())


-@app.post("/v1/embeddings")
+@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
@@ -167,8 +159,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
        return JSONResponse(content=generator.model_dump())


-if __name__ == "__main__":
-    args = parse_args()
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path

    app.add_middleware(
        CORSMiddleware,
@@ -178,6 +172,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
        allow_headers=args.allowed_headers,
    )

+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
    if token := envs.VLLM_API_KEY or args.api_key:

        @app.middleware("http")
@@ -203,6 +203,12 @@ async def authentication(request: Request, call_next):
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

+    return app
+
+
+def run_server(args, llm_engine=None):
+    app = build_app(args)
+
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

@@ -211,10 +217,12 @@ async def authentication(request: Request, call_next):
    else:
        served_model_names = [args.model]

-    engine_args = AsyncEngineArgs.from_cli_args(args)
+    global engine, engine_args

-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

    event_loop: Optional[asyncio.AbstractEventLoop]
    try:
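The new llm_engine parameter lets a caller hand run_server() an engine it has already constructed instead of having one built from the CLI args. A rough sketch of that usage; the import paths for AsyncEngineArgs, AsyncLLMEngine and make_arg_parser are assumed from the usual vLLM layout, and the flag values are examples only:

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m"])  # example flags only

# Build the engine up front, e.g. to share it with other in-process code ...
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs.from_cli_args(args),
    usage_context=UsageContext.OPENAI_API_SERVER)

# ... then let the API server reuse it rather than create its own.
run_server(args, llm_engine=engine)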
@@ -230,6 +238,10 @@ async def authentication(request: Request, call_next):
        # When using single vLLM without engine_use_ray
        model_config = asyncio.run(engine.get_model_config())

+    global openai_serving_chat
+    global openai_serving_completion
+    global openai_serving_embedding
+
    openai_serving_chat = OpenAIServingChat(engine, model_config,
                                            served_model_names,
                                            args.response_role,
@@ -258,3 +270,13 @@ async def authentication(request: Request, call_next):
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+    run_server(args)
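The NOTE above asks that this __main__ block stay in sync with the CLI entrypoint in vllm/scripts.py, which is not part of this diff. A hypothetical wrapper of the shape that comment refers to might look like the sketch below; everything other than run_server, make_arg_parser and FlexibleArgumentParser is illustrative only:

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser  # assumed module path
from vllm.utils import FlexibleArgumentParser


def serve_cli() -> None:
    # Mirrors the __main__ block of the API server so the two stay in sync.
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    run_server(parser.parse_args())


if __name__ == "__main__":
    serve_cli()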