|
1 | 1 | import asyncio |
2 | 2 | import code |
3 | | -import json |
4 | 3 | import os |
5 | | -import signal |
6 | 4 | import subprocess |
7 | 5 | import traceback |
8 | 6 | from logging import Logger |
9 | | -from typing import AsyncGenerator, Dict, List, Optional |
10 | 7 |
|
11 | | -import vllm.envs as envs |
12 | | -from fastapi import APIRouter, BackgroundTasks, Request |
13 | | -from fastapi.responses import Response, StreamingResponse |
14 | | -from vllm.engine.async_llm_engine import AsyncEngineDeadError |
15 | 8 | from vllm.engine.protocol import EngineClient |
16 | | -from vllm.entrypoints.launcher import serve_http |
17 | | -from vllm.entrypoints.openai.api_server import ( |
18 | | - build_app, |
19 | | - build_async_engine_client, |
20 | | - init_app_state, |
21 | | - load_log_config, |
22 | | - maybe_register_tokenizer_info_endpoint, |
23 | | - setup_server, |
24 | | -) |
| 9 | +from vllm.entrypoints.openai.api_server import run_server |
25 | 10 | from vllm.entrypoints.openai.cli_args import make_arg_parser |
26 | | -from vllm.entrypoints.openai.tool_parsers import ToolParserManager |
27 | | -from vllm.outputs import CompletionOutput |
28 | | -from vllm.sampling_params import SamplingParams |
29 | | -from vllm.sequence import Logprob |
30 | | -from vllm.utils import FlexibleArgumentParser, random_uuid |
| 11 | +from vllm.utils import FlexibleArgumentParser |
31 | 12 |
|
32 | 13 | logger = Logger("vllm_server") |
33 | 14 |
|
|
36 | 17 | TIMEOUT_KEEP_ALIVE = 5 # seconds. |
37 | 18 | TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds |
38 | 19 |
|
39 | | -router = APIRouter() |
40 | | - |
41 | | - |
42 | | -@router.post("/predict") |
43 | | -@router.post("/stream") |
44 | | -async def generate(request: Request) -> Response: |
45 | | - """Generate completion for the request. |
46 | | -
|
47 | | - The request should be a JSON object with the following fields: |
48 | | - - prompt: the prompt to use for the generation. |
49 | | - - stream: whether to stream the results or not. |
50 | | - - other fields: the sampling parameters (See `SamplingParams` for details). |
51 | | - """ |
52 | | - # check health before accepting request and fail fast if engine isn't healthy |
53 | | - try: |
54 | | - await engine_client.check_health() |
55 | | - |
56 | | - request_dict = await request.json() |
57 | | - prompt = request_dict.pop("prompt") |
58 | | - stream = request_dict.pop("stream", False) |
59 | | - |
60 | | - sampling_params = SamplingParams(**request_dict) |
61 | | - |
62 | | - request_id = random_uuid() |
63 | | - |
64 | | - results_generator = engine_client.generate(prompt, sampling_params, request_id) |
65 | | - |
66 | | - async def abort_request() -> None: |
67 | | - await engine_client.abort(request_id) |
68 | | - |
69 | | - if stream: |
70 | | - # Streaming case |
71 | | - async def stream_results() -> AsyncGenerator[str, None]: |
72 | | - last_output_text = "" |
73 | | - async for request_output in results_generator: |
74 | | - log_probs = format_logprobs(request_output) |
75 | | - ret = { |
76 | | - "text": request_output.outputs[-1].text[len(last_output_text) :], |
77 | | - "count_prompt_tokens": len(request_output.prompt_token_ids), |
78 | | - "count_output_tokens": len(request_output.outputs[0].token_ids), |
79 | | - "log_probs": ( |
80 | | - log_probs[-1] if log_probs and sampling_params.logprobs else None |
81 | | - ), |
82 | | - "finished": request_output.finished, |
83 | | - } |
84 | | - last_output_text = request_output.outputs[-1].text |
85 | | - yield f"data:{json.dumps(ret)}\n\n" |
86 | | - |
87 | | - background_tasks = BackgroundTasks() |
88 | | - # Abort the request if the client disconnects. |
89 | | - background_tasks.add_task(abort_request) |
90 | | - |
91 | | - return StreamingResponse(stream_results(), background=background_tasks) |
92 | | - |
93 | | - # Non-streaming case |
94 | | - final_output = None |
95 | | - tokens = [] |
96 | | - last_output_text = "" |
97 | | - async for request_output in results_generator: |
98 | | - tokens.append(request_output.outputs[-1].text[len(last_output_text) :]) |
99 | | - last_output_text = request_output.outputs[-1].text |
100 | | - if await request.is_disconnected(): |
101 | | - # Abort the request if the client disconnects. |
102 | | - await engine_client.abort(request_id) |
103 | | - return Response(status_code=499) |
104 | | - final_output = request_output |
105 | | - |
106 | | - assert final_output is not None |
107 | | - prompt = final_output.prompt |
108 | | - ret = { |
109 | | - "text": final_output.outputs[0].text, |
110 | | - "count_prompt_tokens": len(final_output.prompt_token_ids), |
111 | | - "count_output_tokens": len(final_output.outputs[0].token_ids), |
112 | | - "log_probs": format_logprobs(final_output), |
113 | | - "tokens": tokens, |
114 | | - } |
115 | | - return Response(content=json.dumps(ret)) |
116 | | - |
117 | | - except AsyncEngineDeadError as e: |
118 | | - logger.error(f"The vllm engine is dead, exiting the pod: {e}") |
119 | | - os.kill(os.getpid(), signal.SIGINT) |
120 | | - raise e |
| 20 | +# Legacy endpoints /predict and /stream removed - using vLLM's native OpenAI-compatible endpoints instead |
| 21 | +# All requests now go through /v1/completions, /v1/chat/completions, etc. |
121 | 22 |
|
122 | 23 |
|
123 | 24 | def get_gpu_free_memory(): |
@@ -171,90 +72,18 @@ def debug(sig, frame): |
171 | 72 | i.interact(message) |
172 | 73 |
|
173 | 74 |
|
174 | | -def format_logprobs( |
175 | | - request_output: CompletionOutput, |
176 | | -) -> Optional[List[Dict[int, float]]]: |
177 | | - """Given a request output, format the logprobs if they exist.""" |
178 | | - output_logprobs = request_output.outputs[0].logprobs |
179 | | - if output_logprobs is None: |
180 | | - return None |
181 | | - |
182 | | - def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: |
183 | | - return {k: v.logprob for k, v in logprobs.items()} |
184 | | - |
185 | | - return [extract_logprobs(logprobs) for logprobs in output_logprobs] |
186 | | - |
187 | | - |
188 | 75 | def parse_args(parser: FlexibleArgumentParser): |
189 | 76 | parser = make_arg_parser(parser) |
190 | 77 | parser.add_argument("--attention-backend", type=str, help="The attention backend to use") |
191 | 78 | return parser.parse_args() |
192 | 79 |
|
193 | 80 |
|
194 | | -async def run_server(args, **uvicorn_kwargs) -> None: |
195 | | - """Run a single-worker API server.""" |
196 | | - listen_address, sock = setup_server(args) |
197 | | - await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) |
198 | | - |
199 | | - |
200 | | -async def run_server_worker( |
201 | | - listen_address, sock, args, client_config=None, **uvicorn_kwargs |
202 | | -) -> None: |
203 | | - """Run a single API server worker.""" |
204 | | - |
205 | | - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: |
206 | | - ToolParserManager.import_tool_parser(args.tool_parser_plugin) |
207 | | - |
208 | | - server_index = client_config.get("client_index", 0) if client_config else 0 |
209 | | - |
210 | | - # Load logging config for uvicorn if specified |
211 | | - log_config = load_log_config(args.log_config_file) |
212 | | - if log_config is not None: |
213 | | - uvicorn_kwargs["log_config"] = log_config |
214 | | - |
215 | | - global engine_client |
216 | | - |
217 | | - async with build_async_engine_client(args, client_config=client_config) as engine_client: |
218 | | - maybe_register_tokenizer_info_endpoint(args) |
219 | | - app = build_app(args) |
220 | | - |
221 | | - vllm_config = await engine_client.get_vllm_config() |
222 | | - await init_app_state(engine_client, vllm_config, app.state, args) |
223 | | - app.include_router(router) |
224 | | - |
225 | | - logger.info("Starting vLLM API server %d on %s", server_index, listen_address) |
226 | | - shutdown_task = await serve_http( |
227 | | - app, |
228 | | - sock=sock, |
229 | | - enable_ssl_refresh=args.enable_ssl_refresh, |
230 | | - host=args.host, |
231 | | - port=args.port, |
232 | | - log_level=args.uvicorn_log_level, |
233 | | - # NOTE: When the 'disable_uvicorn_access_log' value is True, |
234 | | - # no access log will be output. |
235 | | - access_log=not args.disable_uvicorn_access_log, |
236 | | - timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, |
237 | | - ssl_keyfile=args.ssl_keyfile, |
238 | | - ssl_certfile=args.ssl_certfile, |
239 | | - ssl_ca_certs=args.ssl_ca_certs, |
240 | | - ssl_cert_reqs=args.ssl_cert_reqs, |
241 | | - h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, |
242 | | - h11_max_header_count=args.h11_max_header_count, |
243 | | - **uvicorn_kwargs, |
244 | | - ) |
245 | | - |
246 | | - # NB: Await server shutdown only after the backend context is exited |
247 | | - try: |
248 | | - await shutdown_task |
249 | | - finally: |
250 | | - sock.close() |
251 | | - |
252 | | - |
253 | 81 | if __name__ == "__main__": |
254 | 82 | check_unknown_startup_memory_usage() |
255 | 83 |
|
256 | 84 | parser = FlexibleArgumentParser() |
257 | 85 | args = parse_args(parser) |
258 | 86 | if args.attention_backend is not None: |
259 | 87 | os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_backend |
| 88 | + # Delegate to vLLM's built-in run_server, which serves the OpenAI-compatible API |
260 | 89 | asyncio.run(run_server(args)) |
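
With the custom /predict and /stream routes removed, clients call vLLM's standard OpenAI-compatible endpoints directly. The snippet below is a minimal migration sketch, not part of the diff above; the base URL and model name are placeholder assumptions and should be replaced with whatever the server was actually launched with.

import json

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port; adjust to wherever the server runs
MODEL = "my-model"                  # placeholder for whatever --model the server was started with

# Non-streaming request, covering what the old /predict route returned
# (generated text plus per-token logprobs; token counts are in "usage").
resp = requests.post(
    f"{BASE_URL}/v1/completions",
    json={"model": MODEL, "prompt": "Hello, world", "max_tokens": 32, "logprobs": 1},
    timeout=60,
)
resp.raise_for_status()
data = resp.json()
print(data["choices"][0]["text"])

# Streaming request, covering what the old /stream route did. vLLM emits
# Server-Sent Events: each line is "data: <json>", terminated by "data: [DONE]".
with requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": MODEL,
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
    timeout=60,
) as stream_resp:
    stream_resp.raise_for_status()
    for line in stream_resp.iter_lines():
        if not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"].get("content", ""), end="")
    print()

Any caller that still hard-codes /predict or /stream needs the corresponding client-side change; the response schema is now the standard OpenAI one rather than the custom JSON the removed handlers produced.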