Commit 0bacaa5

Update vllm server to be openai compatible (#560)
* Update vllm engine to be openai compatible
* Bump vllm to 0.5.1
* Revert 0.5.1 -- need some CUDA version upgrade
* Small cleanup
1 parent d5d9193 commit 0bacaa5

File tree

1 file changed: +92 -11 lines changed

model-engine/model_engine_server/inference/vllm/vllm_server.py

Lines changed: 92 additions & 11 deletions
@@ -1,33 +1,86 @@
 import argparse
+import asyncio
 import code
 import json
+import logging
 import os
 import signal
 import subprocess
 import traceback
+from logging import Logger
 from typing import AsyncGenerator, Dict, List, Optional

 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest as OpenAIChatCompletionRequest
+from vllm.entrypoints.openai.protocol import ChatCompletionResponse as OpenAIChatCompletionResponse
 from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
+from vllm.entrypoints.openai.protocol import ErrorResponse
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 from vllm.outputs import CompletionOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid
+from vllm.version import __version__ as VLLM_VERSION
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s: %(message)s",
+    datefmt="%b/%d %H:%M:%S",
+    level=logging.INFO,
+)
+
+logger = Logger("vllm_server")

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()

+openai_serving_chat: OpenAIServingChat
+openai_serving_completion: OpenAIServingCompletion
+openai_serving_embedding: OpenAIServingEmbedding
+

 @app.get("/healthz")
 @app.get("/health")
-def healthcheck():
-    return "OK"
+async def healthcheck():
+    await openai_serving_chat.engine.check_health()
+    return Response(status_code=200)
+
+
+@app.get("/v1/models")
+async def show_available_models():
+    models = await openai_serving_chat.show_available_models()
+    return JSONResponse(content=models.model_dump())
+
+
+@app.post("/v1/chat/completions")
+async def create_chat_completion(request: OpenAIChatCompletionRequest, raw_request: Request):
+    generator = await openai_serving_chat.create_chat_completion(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+    if request.stream:
+        return StreamingResponse(content=generator, media_type="text/event-stream")
+    else:
+        assert isinstance(generator, OpenAIChatCompletionResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@app.post("/v1/completions")
+async def create_completion(request: OpenAICompletionRequest, raw_request: Request):
+    generator = await openai_serving_completion.create_completion(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+    if request.stream:
+        return StreamingResponse(content=generator, media_type="text/event-stream")
+    else:
+        return JSONResponse(content=generator.model_dump())


 @app.post("/predict")
@@ -135,7 +188,7 @@ async def stream_results() -> AsyncGenerator[str, None]:
         return Response(content=json.dumps(ret))

     except AsyncEngineDeadError as e:
-        print(f"The vllm engine is dead, exiting the pod: {e}")
+        logger.error(f"The vllm engine is dead, exiting the pod: {e}")
         os.kill(os.getpid(), signal.SIGINT)
         raise e

@@ -151,7 +204,7 @@ def get_gpu_free_memory():
         gpu_memory = [int(x) for x in output.strip().split("\n")]
         return gpu_memory
     except Exception as e:
-        print(f"Error getting GPU memory: {e}")
+        logger.warn(f"Error getting GPU memory: {e}")
         return None


@@ -162,17 +215,17 @@ def check_unknown_startup_memory_usage():
         min_mem = min(gpu_free_memory)
         max_mem = max(gpu_free_memory)
         if max_mem - min_mem > 10:
-            print(
+            logger.warn(
                 f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}."
             )
             try:
                 # nosemgrep
                 output = subprocess.run(
                     ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True
                 ).stdout
-                print(f"Processes using GPU: {output}")
+                logger.info(f"Processes using GPU: {output}")
             except Exception as e:
-                print(f"Error getting processes using GPU: {e}")
+                logger.error(f"Error getting processes using GPU: {e}")


 def debug(sig, frame):
@@ -200,23 +253,51 @@ def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]:
     return [extract_logprobs(logprobs) for logprobs in output_logprobs]


+def parse_args():
+    parser = make_arg_parser()
+    return parser.parse_args()
+
+
 if __name__ == "__main__":
     check_unknown_startup_memory_usage()
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default=None)  # None == IPv4 / IPv6 dualstack
     parser.add_argument("--port", type=int, default=5005)
     parser = AsyncEngineArgs.add_cli_args(parser)
-    args = parser.parse_args()
+    args = parse_args()
+
+    logger.info("vLLM version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    if args.served_model_name is not None:
+        served_model_names = args.served_model_name
+    else:
+        served_model_names = [args.model]
+
+    signal.signal(signal.SIGUSR1, debug)

     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)

-    signal.signal(signal.SIGUSR1, debug)
+    model_config = asyncio.run(engine.get_model_config())
+
+    openai_serving_chat = OpenAIServingChat(
+        engine,
+        model_config,
+        served_model_names,
+        args.response_role,
+        args.lora_modules,
+        args.chat_template,
+    )
+    openai_serving_completion = OpenAIServingCompletion(
+        engine, model_config, served_model_names, args.lora_modules
+    )

     uvicorn.run(
         app,
         host=args.host,
         port=args.port,
-        log_level="debug",
+        log_level=args.uvicorn_log_level,
         timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
     )
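For context on what this change exposes, below is a minimal client-side sketch (not part of the commit) that exercises the new OpenAI-compatible routes. It assumes the server is running locally on the default port 5005; the model name "my-model" is a hypothetical placeholder for whatever /v1/models reports.

# Minimal client sketch, assuming the server was launched with something like
#   python vllm_server.py --model <model> --port 5005
# "my-model" below is a hypothetical placeholder for a name returned by /v1/models.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5005/v1", api_key="unused")

# New /v1/models route: list the served model names.
for model in client.models.list():
    print(model.id)

# New /v1/chat/completions route, non-streaming.
chat = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(chat.choices[0].message.content)

# New /v1/completions route, streaming (served as text/event-stream).
for chunk in client.completions.create(model="my-model", prompt="Hello", stream=True):
    print(chunk.choices[0].text, end="", flush=True)

Note that the existing /predict route is untouched by the diff, so current callers keep working while new clients can use the standard OpenAI SDK.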
