# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py

import asyncio
import signal
import tempfile
from argparse import Namespace
from http import HTTPStatus
from typing import Optional

import uvloop
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (AnthropicErrorResponse,
                                                 AnthropicMessagesRequest,
                                                 AnthropicMessagesResponse)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.chat_utils import (load_chat_template,
                                         resolve_hf_chat_template,
                                         resolve_mistral_chat_template)
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.api_server import (build_async_engine_client,
                                                create_server_socket,
                                                lifespan, load_log_config,
                                                validate_api_server_args,
                                                validate_json_request)
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    LoRAModulePath,
                                                    OpenAIServingModels)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
                                    with_cancellation)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import (FlexibleArgumentParser, is_valid_ipv6_address,
                        set_ulimit)
from vllm.version import __version__ as VLLM_VERSION

prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.anthropic.api_server')

_running_tasks: set[asyncio.Task] = set()

router = APIRouter()


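# Helpers for pulling per-application state (set up in init_app_state below)
# off an incoming request.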
def messages(request: Request) -> Optional[AnthropicServingMessages]:
    return request.app.state.anthropic_serving_messages


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
    await engine_client(raw_request).check_health()
    return Response(status_code=200)


@router.get("/ping", response_class=Response)
@router.post("/ping", response_class=Response)
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker."""
    return await health(raw_request)
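
# A quick way to exercise these probes against a running server (host and
# port are assumptions, not values fixed by this module):
#
#   curl -i http://localhost:8000/health   # 200 once the engine reports healthy
#   curl -i http://localhost:8000/ping     # same check via the SageMaker path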


@router.post("/v1/messages",
             dependencies=[Depends(validate_json_request)],
             responses={
                 HTTPStatus.OK.value: {
                     "content": {
                         "text/event-stream": {}
                     }
                 },
                 HTTPStatus.BAD_REQUEST.value: {
                     "model": AnthropicErrorResponse
                 },
                 HTTPStatus.NOT_FOUND.value: {
                     "model": AnthropicErrorResponse
                 },
                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
                     "model": AnthropicErrorResponse
                 }
             })
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest,
                          raw_request: Request):
    handler = messages(raw_request)
    if handler is None:
        # No Messages handler was initialized for this app, so build an
        # Anthropic-style error payload directly instead of delegating to it.
        return JSONResponse(
            content={
                "type": "error",
                "error": {
                    "type": "invalid_request_error",
                    "message": "The model does not support the Messages API",
                },
            },
            status_code=HTTPStatus.BAD_REQUEST.value)

    generator = await handler.create_messages(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    elif isinstance(generator, AnthropicMessagesResponse):
        return JSONResponse(content=generator.model_dump(exclude_none=True,
                                                         exclude_unset=True))

    return StreamingResponse(content=generator, media_type="text/event-stream")
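
# A minimal example request against this endpoint, assuming the server is
# listening on localhost:8000 and serving a model registered as "my-model"
# (both values are assumptions, not fixed by this module):
#
#   curl http://localhost:8000/v1/messages \
#     -H "content-type: application/json" \
#     -d '{"model": "my-model",
#          "max_tokens": 128,
#          "messages": [{"role": "user", "content": "Hello"}]}'
#
# Setting "stream": true in the payload selects the text/event-stream branch
# returned via StreamingResponse above.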


async def init_app_state(
    engine_client: EngineClient,
    vllm_config: VllmConfig,
    state: State,
    args: Namespace,
) -> None:
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    model_config = vllm_config.model_config

    default_mm_loras = (vllm_config.lora_config.default_mm_loras
                        if vllm_config.lora_config is not None else {})
    lora_modules = args.lora_modules
    if default_mm_loras:
        default_mm_lora_paths = [
            LoRAModulePath(
                name=modality,
                path=lora_path,
            ) for modality, lora_path in default_mm_loras.items()
        ]
        if args.lora_modules is None:
            lora_modules = default_mm_lora_paths
        else:
            lora_modules += default_mm_lora_paths

    resolved_chat_template = load_chat_template(args.chat_template)
    if resolved_chat_template is not None:
        # Get the tokenizer to check against the official template.
        tokenizer = await engine_client.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            # The warning is logged in resolve_mistral_chat_template.
            resolved_chat_template = resolve_mistral_chat_template(
                chat_template=resolved_chat_template)
        else:
            hf_chat_template = resolve_hf_chat_template(
                tokenizer=tokenizer,
                chat_template=None,
                tools=None,
                model_config=vllm_config.model_config,
            )

            if hf_chat_template != resolved_chat_template:
                logger.warning(
                    "Using supplied chat template: %s\n"
                    "It is different from the official chat template for "
                    "model '%s'. This discrepancy may lead to performance "
                    "degradation.", resolved_chat_template, args.model)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()
    state.anthropic_serving_messages = AnthropicServingMessages(
        engine_client,
        model_config,
        state.openai_serving_models,
        args.response_role,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
    )
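
# The args consumed above come from the OpenAI-compatible CLI parser built by
# make_arg_parser; the flag spellings below are inferred from the attribute
# names used in this function and may vary between vLLM versions:
#
#   --served-model-name, --chat-template, --chat-template-content-format,
#   --response-role, --lora-modules, --enable-auto-tool-choice,
#   --tool-call-parser, --reasoning-parser, --max-log-len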


def setup_server(args):
    """Validate API server args, set up signal handler, create socket
    ready to serve."""

    logger.info("vLLM API server version %s", VLLM_VERSION)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    validate_api_server_args(args)

    # Workaround to make sure that we bind the port before the engine is set
    # up. This avoids race conditions with ray.
    # See https://github.com/vllm-project/vllm/issues/8204
    sock_addr = (args.host or "", args.port)
    sock = create_server_socket(sock_addr)

    # Workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active.
    set_ulimit()

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    addr, port = sock_addr
    is_ssl = args.ssl_keyfile and args.ssl_certfile
    host_part = (f"[{addr}]" if is_valid_ipv6_address(addr)
                 else addr or "0.0.0.0")
    listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"

    return listen_address, sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server."""
    listen_address, sock = setup_server(args)
    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    return app


async def run_server_worker(listen_address,
                            sock,
                            args,
                            client_config=None,
                            **uvicorn_kwargs) -> None:
    """Run a single API server worker."""

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    server_index = client_config.get("client_index", 0) if client_config else 0

    # Load logging config for uvicorn if specified
    log_config = load_log_config(args.log_config_file)
    if log_config is not None:
        uvicorn_kwargs['log_config'] = log_config

    async with build_async_engine_client(
            args,
            client_config=client_config,
    ) as engine_client:
        app = build_app(args)

        vllm_config = await engine_client.get_vllm_config()
        await init_app_state(engine_client, vllm_config, app.state, args)

        logger.info("Starting vLLM API server %d on %s", server_index,
                    listen_address)
        shutdown_task = await serve_http(
            app,
            sock=sock,
            enable_ssl_refresh=args.enable_ssl_refresh,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            # NOTE: When the 'disable_uvicorn_access_log' value is True,
            # no access log will be output.
            access_log=not args.disable_uvicorn_access_log,
            timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    try:
        await shutdown_task
    finally:
        sock.close()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()
    parser = FlexibleArgumentParser(
        description="vLLM Anthropic-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
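
# A minimal sketch of launching this server directly (the module path is an
# assumption based on the logger name above; flag values are placeholders):
#
#   python -m vllm.entrypoints.anthropic.api_server \
#       --model <model-name-or-path> --host 0.0.0.0 --port 8000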