diff --git a/tests/entrypoints/anthropic/__init__.py b/tests/entrypoints/anthropic/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/entrypoints/anthropic/test_messages.py b/tests/entrypoints/anthropic/test_messages.py new file mode 100644 index 000000000000..2f0521a85b32 --- /dev/null +++ b/tests/entrypoints/anthropic/test_messages.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import pytest_asyncio +import anthropic +from ...utils import RemoteAnthropicServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" + + +@pytest.fixture(scope="module") +def server(): # noqa: F811 + args = [ + "--max-model-len", "8192", "--enforce-eager", + "--enable-auto-tool-choice", "--tool-call-parser", "hermes", + "--served-model-name", "claude-3-7-sonnet-latest" + ] + + with RemoteAnthropicServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_simple_messages(client: anthropic.Anthropic): + resp = client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=8192, + messages=[ + { + "role": "user", + "content": "how are you!" + } + ], + ) + assert resp.stop_reason == "end_turn" + assert resp.role == "assistant" + + print(f"Anthropic response: {resp.model_dump_json()}") + + +@pytest.mark.asyncio +def test_system_message(client: anthropic.Anthropic): + resp = client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=8192, + system="you are a helpful assistant", + messages=[ + { + "role": "user", + "content": "how are you!" + } + ], + ) + assert resp.stop_reason == "end_turn" + assert resp.role == "assistant" + + print(f"Anthropic response: {resp.model_dump_json()}") + + +@pytest.mark.asyncio +def test_anthropic_streaming(client: anthropic.Anthropic): + resp = client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=8192, + messages=[ + { + "role": "user", + "content": "how are you!" + } + ], + stream=True, + ) + assert resp.stop_reason == "end_turn" + assert resp.role == "assistant" + + for chunk in resp: + print(chunk.model_dump_json()) + + +@pytest.mark.asyncio +def test_anthropic_tool_call(client: anthropic.Anthropic): + resp = client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=8192, + messages=[ + { + "role": "user", + "content": "What's the weather like in New York today?" + } + ], + tools=[ + { + "name": "get_current_weather", + "description": "Useful for querying the weather in a specified city.", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City or region, for example: New York, London, Tokyo, etc." + } + }, + "required": ["location"] + } + } + + ], + stream=False, + ) + assert resp.stop_reason == "tool_use" + assert resp.role == "assistant" + + print(f'Anthropic response: {resp.model_dump_json()}') + + @pytest.mark.asyncio + def test_anthropic_tool_call_streaming(client: anthropic.Anthropic): + resp = client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=8192, + messages=[ + { + "role": "user", + "content": "What's the weather like in New York today?" 
+ } + ], + tools=[ + { + "name": "get_current_weather", + "description": "Useful for querying the weather in a specified city.", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City or region, for example: New York, London, Tokyo, etc." + } + }, + "required": ["location"] + } + } + + ], + stream=True, + ) + + for chunk in resp: + print(chunk.model_dump_json()) diff --git a/tests/entrypoints/anthropic/test_serving_messages.py b/tests/entrypoints/anthropic/test_serving_messages.py new file mode 100644 index 000000000000..2281150408df --- /dev/null +++ b/tests/entrypoints/anthropic/test_serving_messages.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from contextlib import suppress +from dataclasses import dataclass, field +from typing import Any, Optional +from unittest.mock import MagicMock + +import pytest + +from vllm.config import MultiModalConfig +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.entrypoints.anthropic.protocol import AnthropicMessagesRequest +from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.transformers_utils.tokenizer import get_tokenizer + +MODEL_NAME = "openai-community/gpt2" +CHAT_TEMPLATE = "Dummy chat template for testing {}" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + task = "generate" + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() + logits_processor_pattern = None + diff_sampling_param: Optional[dict] = None + allowed_local_media_path: str = "" + encoder_config = None + generation_config: str = "auto" + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + + def get_diff_sampling_param(self): + return self.diff_sampling_param or {} + + +@dataclass +class MockEngine: + + async def get_model_config(self): + return MockModelConfig() + + +async def _async_serving_chat_init(): + engine = MockEngine() + model_config = await engine.get_model_config() + + models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS) + serving_completion = AnthropicServingMessages(engine, + model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + return serving_completion + + +def test_async_serving_chat_init(): + serving_completion = asyncio.run(_async_serving_chat_init()) + assert serving_completion.chat_template == CHAT_TEMPLATE + + +@pytest.mark.asyncio +async def test_serving_chat_should_set_correct_max_tokens(): + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=MockModelConfig()) + serving_messages = AnthropicServingMessages(mock_engine, + MockModelConfig(), + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + + req = AnthropicMessagesRequest( + model=MODEL_NAME, + messages=[{ + 
"role": "user", + "content": "what is 1+1?" + }], + ) + + with suppress(Exception): + await serving_messages.create_messages(req) + + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + req.max_tokens = 10 + with suppress(Exception): + await serving_messages.create_messages(req) + + assert mock_engine.generate.call_args.args[1].max_tokens == 10 + + # Setting server's max_tokens in the generation_config.json + # lower than context_window - prompt_tokens + mock_model_config = MockModelConfig() + mock_model_config.diff_sampling_param = { + "max_tokens": 10 # Setting server-side max_tokens limit + } + + # Reinitialize the engine with new settings + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + # Initialize the serving chat + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) + serving_chat = AnthropicServingMessages(mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + + # Test Case 1: No max_tokens specified in request + req = AnthropicMessagesRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + ) + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].max_tokens == 10 + + # Test Case 2: Request's max_tokens set higher than server accepts + req.max_tokens = 15 + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].max_tokens == 10 + + # Test Case 3: Request's max_tokens set lower than server accepts + req.max_tokens = 5 + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].max_tokens == 5 + + # Setting server's max_tokens in the generation_config.json + # higher than context_window - prompt_tokens + mock_model_config = MockModelConfig() + mock_model_config.diff_sampling_param = { + "max_tokens": 200 # Setting server-side max_tokens limit + } + + # Reinitialize the engine with new settings + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + # Initialize the serving chat + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) + serving_chat = AnthropicServingMessages(mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + + # Test case 1: No max_tokens specified, defaults to context_window + req = AnthropicMessagesRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" 
+ }], + ) + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + # Test Case 2: Request's max_tokens set higher than server accepts + req.max_tokens = 100 + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + # Test Case 3: Request's max_tokens set lower than server accepts + req.max_tokens = 5 + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].max_tokens == 5 + + +@pytest.mark.asyncio +async def test_serving_chat_could_load_correct_generation_config(): + + mock_model_config = MockModelConfig() + mock_model_config.diff_sampling_param = { + "temperature": 0.5, + "repetition_penalty": 1.05 + } + + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + # Initialize the serving chat + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) + serving_chat = AnthropicServingMessages(mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + + req = AnthropicMessagesRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + ) + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].temperature == 0.5 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + # Test the param when user set it + req.temperature = 0.1 + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].temperature == 0.1 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + # Test When temperature==0.0 + req.temperature = 0.0 + + with suppress(Exception): + await serving_chat.create_messages(req) + + assert mock_engine.generate.call_args.args[1].temperature == 0.0 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + diff --git a/tests/utils.py b/tests/utils.py index 1c1a1cc6014e..ce5d088aa342 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,6 +16,7 @@ from pathlib import Path from typing import Any, Callable, Literal, Optional, Union +import anthropic import cloudpickle import openai import pytest @@ -194,6 +195,130 @@ def get_async_client(self, **kwargs): **kwargs) +class RemoteAnthropicServer: + DUMMY_API_KEY = "token-abc123" # vLLM's Anthropic server does not need API key + def __init__(self, + model: str, + vllm_serve_args: list[str], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None) -> None: + if auto_port: + if "-p" in vllm_serve_args or "--port" in vllm_serve_args: + raise ValueError("You have manually specified the port " + "when `auto_port=True`.") + + # Don't mutate the input args + vllm_serve_args = vllm_serve_args + [ + "--port", str(get_open_port()) + ] + if seed is not None: + if "--seed" in vllm_serve_args: + raise ValueError("You have manually specified the seed " + f"when `seed={seed}`.") + + vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] + + parser = FlexibleArgumentParser( + description="vLLM's remote Anthropic server.") + subparsers = 
parser.add_subparsers(required=False, dest="subparser") + parser = ServeSubcommand().subparser_init(subparsers) + args = parser.parse_args(["--model", model, *vllm_serve_args]) + self.host = str(args.host or 'localhost') + self.port = int(args.port) + + self.show_hidden_metrics = \ + args.show_hidden_metrics_for_version is not None + + # download the model before starting the server to avoid timeout + is_local = os.path.isdir(model) + if not is_local: + engine_args = AsyncEngineArgs.from_cli_args(args) + model_config = engine_args.create_model_config() + load_config = engine_args.create_load_config() + + model_loader = get_model_loader(load_config) + model_loader.download_model(model_config) + + env = os.environ.copy() + # the current process might initialize cuda, + # to be safe, we should use spawn method + env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) + self.proc = subprocess.Popen( + [sys.executable, "-m", "vllm.entrypoints.anthropic.api_server", "--model", model, *vllm_serve_args], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + max_wait_seconds = max_wait_seconds or 240 + self._wait_for_server(url=self.url_for("health"), + timeout=max_wait_seconds) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + try: + self.proc.wait(8) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() + + def _wait_for_server(self, *, url: str, timeout: float): + # run health check + start = time.time() + while True: + try: + if requests.get(url).status_code == 200: + break + except Exception: + # this exception can only be raised by requests.get, + # which means the server is not ready yet. + # the stack trace is not useful, so we suppress it + # by using `raise from None`.
+ result = self.proc.poll() + if result is not None and result != 0: + raise RuntimeError("Server exited unexpectedly.") from None + + time.sleep(0.5) + if time.time() - start > timeout: + raise RuntimeError( + "Server failed to start in time.") from None + + @property + def url_root(self) -> str: + return f"http://{self.host}:{self.port}" + + def url_for(self, *parts: str) -> str: + return self.url_root + "/" + "/".join(parts) + + def get_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return anthropic.Anthropic( + base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) + + def get_async_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return anthropic.AsyncAnthropic( + base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs + ) + + + def _test_completion( client: openai.OpenAI, model: str, diff --git a/vllm/entrypoints/anthropic/__init__.py b/vllm/entrypoints/anthropic/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/anthropic/api_server.py b/vllm/entrypoints/anthropic/api_server.py new file mode 100644 index 000000000000..7443f5bdc6e4 --- /dev/null +++ b/vllm/entrypoints/anthropic/api_server.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from: +# https://github.com/vllm/vllm/entrypoints/openai/api_server.py + +import asyncio +import signal +import tempfile +from argparse import Namespace +from http import HTTPStatus +from typing import Optional + +import uvloop +from fastapi import APIRouter, Depends, FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.anthropic.protocol import AnthropicErrorResponse, AnthropicMessagesRequest, \ + AnthropicMessagesResponse +from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages +from vllm.entrypoints.chat_utils import (load_chat_template, + resolve_hf_chat_template, + resolve_mistral_chat_template) +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.api_server import validate_api_server_args, create_server_socket, load_log_config, \ + lifespan, build_async_engine_client, validate_json_request +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +from vllm.entrypoints.openai.protocol import ErrorResponse +from vllm.entrypoints.openai.serving_models import OpenAIServingModels, BaseModelPath, LoRAModulePath +# +# yapf: enable +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, + with_cancellation) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import MistralTokenizer +from vllm.utils import (FlexibleArgumentParser, + is_valid_ipv6_address, + set_ulimit) +from vllm.version import __version__ as VLLM_VERSION + +prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger('vllm.entrypoints.anthropic.api_server') 
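+# NOTE: this entrypoint reuses the OpenAI serving stack: AnthropicServingMessages (serving_messages.py) converts incoming /v1/messages requests into ChatCompletionRequest objects, delegates to OpenAIServingChat, and maps the resulting completions / stream chunks back into Anthropic responses and SSE events.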
+ +_running_tasks: set[asyncio.Task] = set() + +router = APIRouter() + + +def messages(request: Request) -> Optional[AnthropicServingMessages]: + return request.app.state.anthropic_serving_messages + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health", response_class=Response) +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.get("/ping", response_class=Response) +@router.post("/ping", response_class=Response) +async def ping(raw_request: Request) -> Response: + """Ping check. Endpoint required for SageMaker""" + return await health(raw_request) + + +@router.post("/v1/messages", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": AnthropicErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": AnthropicErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": AnthropicErrorResponse + } + }) +@with_cancellation +@load_aware_call +async def create_messages(request: AnthropicMessagesRequest, + raw_request: Request): + handler = messages(raw_request) + if handler is None: + return messages(raw_request).create_error_response( + message="The model does not support Chat Completions API") + + generator = await handler.create_messages(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, AnthropicMessagesResponse): + return JSONResponse(content=generator.model_dump(exclude_none=True, exclude_unset=True)) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +async def init_app_state( + engine_client: EngineClient, + vllm_config: VllmConfig, + state: State, + args: Namespace, +) -> None: + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + state.vllm_config = vllm_config + model_config = vllm_config.model_config + + default_mm_loras = (vllm_config.lora_config.default_mm_loras + if vllm_config.lora_config is not None else {}) + lora_modules = args.lora_modules + if default_mm_loras: + default_mm_lora_paths = [ + LoRAModulePath( + name=modality, + path=lora_path, + ) for modality, lora_path in default_mm_loras.items() + ] + if args.lora_modules is None: + lora_modules = default_mm_lora_paths + else: + lora_modules += default_mm_lora_paths + + resolved_chat_template = load_chat_template(args.chat_template) + if resolved_chat_template is not None: + # Get the tokenizer to check official template + tokenizer = await engine_client.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + # The warning is logged in resolve_mistral_chat_template. 
+ resolved_chat_template = resolve_mistral_chat_template( + chat_template=resolved_chat_template) + else: + hf_chat_template = resolve_hf_chat_template( + tokenizer=tokenizer, + chat_template=None, + tools=None, + model_config=vllm_config.model_config, + ) + + if hf_chat_template != resolved_chat_template: + logger.warning( + "Using supplied chat template: %s\n" + "It is different from official chat template '%s'. " + "This discrepancy may lead to performance degradation.", + resolved_chat_template, args.model) + + state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=lora_modules, + ) + await state.openai_serving_models.init_static_loras() + state.anthropic_serving_messages = AnthropicServingMessages( + engine_client, + model_config, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) + + +def setup_server(args): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + validate_api_server_args(args) + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address( + addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + """Run a single-worker API server.""" + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +def build_app(args: Namespace) -> FastAPI: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + return app + + +async def run_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + """Run a single API server worker.""" + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + # Load logging config for uvicorn if specified + log_config = load_log_config(args.log_config_file) + if log_config is not None: + uvicorn_kwargs['log_config'] = log_config + + async with build_async_engine_client( + args, + client_config=client_config, + ) as engine_client: + app = build_app(args) + + vllm_config = await engine_client.get_vllm_config() + await init_app_state(engine_client, vllm_config, app.state, args) + + logger.info("Starting vLLM API server %d on %s", server_index, + listen_address) + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. + access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/entrypoints/cli/main.py for CLI + # entrypoints. 
+ cli_env_setup() + parser = FlexibleArgumentParser( + description="vLLM Anthropic-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py new file mode 100644 index 000000000000..c01ec6817b8c --- /dev/null +++ b/vllm/entrypoints/anthropic/protocol.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from +# https://github.com/sgl-project/sglang/blob/220962e46b087b5829137a67eab0205b4d51720b/python/sglang/srt/entrypoints/anthropic/protocol.py +"""Pydantic models for Anthropic API protocol""" + +import time +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field, field_validator, model_validator + + +class AnthropicError(BaseModel): + """Error structure for Anthropic API""" + type: str + message: str + + +class AnthropicErrorResponse(BaseModel): + """Error response structure for Anthropic API""" + type: Literal["error"] = "error" + error: AnthropicError + + +class AnthropicUsage(BaseModel): + """Token usage information""" + input_tokens: int + output_tokens: int + cache_creation_input_tokens: Optional[int] = None + cache_read_input_tokens: Optional[int] = None + + +class AnthropicContentBlock(BaseModel): + """Content block in message""" + type: Literal["text", "image", "tool_use", "tool_result"] + text: Optional[str] = None + # For image content + source: Optional[Dict[str, Any]] = None + # For tool use/result + id: Optional[str] = None + name: Optional[str] = None + input: Optional[Dict[str, Any]] = None + content: Optional[Union[str, List[Dict[str, Any]]]] = None + is_error: Optional[bool] = None + + +class AnthropicMessage(BaseModel): + """Message structure""" + role: Literal["user", "assistant"] + content: Union[str, List[AnthropicContentBlock]] + + +class AnthropicTool(BaseModel): + """Tool definition""" + name: str + description: Optional[str] = None + input_schema: Dict[str, Any] + + @field_validator("input_schema") + @classmethod + def validate_input_schema(cls, v): + if not isinstance(v, dict): + raise ValueError("input_schema must be a dictionary") + if "type" not in v: + v["type"] = "object" # Default to object type + return v + + +class AnthropicToolChoice(BaseModel): + """Tool Choice definition""" + type: Literal["auto", "any", "tool"] + name: Optional[str] = None + + +class AnthropicMessagesRequest(BaseModel): + """Anthropic Messages API request""" + model: str + messages: List[AnthropicMessage] + max_tokens: int + metadata: Optional[Dict[str, Any]] = None + stop_sequences: Optional[List[str]] = None + stream: Optional[bool] = False + system: Optional[str] = None + temperature: Optional[float] = None + tool_choice: Optional[AnthropicToolChoice] = None + tools: Optional[List[AnthropicTool]] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + + @field_validator("model") + @classmethod + def validate_model(cls, v): + if not v: + raise ValueError("Model is required") + return v + + @field_validator("max_tokens") + @classmethod + def validate_max_tokens(cls, v): + if v <= 0: + raise ValueError("max_tokens must be positive") + return v + + +class AnthropicDelta(BaseModel): + """Delta for streaming responses""" + type: Literal["text_delta", "input_json_delta"] = None + text: Optional[str] = None + partial_json: Optional[str] = None + + 
# Message delta + stop_reason: Optional[ + Literal["end_turn", "max_tokens", "stop_sequence", "tool_use", "pause_turn", "refusal"]] = None + stop_sequence: Optional[str] = None + usage: AnthropicUsage = None + + +class AnthropicStreamEvent(BaseModel): + """Streaming event""" + type: Literal[ + "message_start", "message_delta", "message_stop", + "content_block_start", "content_block_delta", "content_block_stop", + "ping", "error" + ] + message: Optional["AnthropicMessagesResponse"] = None + delta: Optional[AnthropicDelta] = None + content_block: Optional[AnthropicContentBlock] = None + index: Optional[int] = None + error: Optional[AnthropicError] = None + + +class AnthropicMessagesResponse(BaseModel): + """Anthropic Messages API response""" + id: str + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + content: List[AnthropicContentBlock] + model: str + stop_reason: Optional[Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]] = None + stop_sequence: Optional[str] = None + usage: AnthropicUsage = None + + def model_post_init(self, __context): + if not self.id: + self.id = f"msg_{int(time.time() * 1000)}" + + +# Forward reference resolution +AnthropicStreamEvent.model_rebuild() diff --git a/vllm/entrypoints/anthropic/serving_messages.py b/vllm/entrypoints/anthropic/serving_messages.py new file mode 100644 index 000000000000..05c15b6c3cdb --- /dev/null +++ b/vllm/entrypoints/anthropic/serving_messages.py @@ -0,0 +1,416 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from +# https://github.com/vllm/vllm/entrypoints/openai/serving_chat.py + +"""Anthropic Messages API serving handler""" +import json +import logging +import time +from typing import AsyncGenerator, List, Optional, Union + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.anthropic.protocol import ( + AnthropicContentBlock, + AnthropicDelta, + AnthropicMessagesRequest, + AnthropicMessagesResponse, + AnthropicStreamEvent, + AnthropicUsage, AnthropicError, +) +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ErrorResponse, ChatCompletionRequest, \ + ChatCompletionNamedToolChoiceParam, ChatCompletionToolsParam, ChatCompletionResponse, ChatCompletionStreamResponse, \ + StreamOptions +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_models import OpenAIServingModels + +logger = logging.getLogger(__name__) + + +class AnthropicServingMessages(OpenAIServingChat): + """Handler for Anthropic Messages API requests""" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + response_role: str, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + reasoning_parser: str = "", + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None, + enable_prompt_tokens_details: bool = False, + enable_force_include_usage: bool = False, + ): + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + response_role=response_role, + request_logger=request_logger, + chat_template=chat_template, + 
chat_template_content_format=chat_template_content_format, + return_tokens_as_token_ids=return_tokens_as_token_ids, + reasoning_parser=reasoning_parser, + enable_auto_tools=enable_auto_tools, + tool_parser=tool_parser, + enable_prompt_tokens_details=enable_prompt_tokens_details, + enable_force_include_usage=enable_force_include_usage, + ) + self.stop_reason_map = { + "stop": "end_turn", + "length": "max_tokens", + "tool_calls": "tool_use", + } + + def _convert_anthropic_to_openai_request( + self, anthropic_request: AnthropicMessagesRequest + ) -> ChatCompletionRequest: + """Convert Anthropic message format to OpenAI format""" + openai_messages = [] + + # Add system message if provided + if anthropic_request.system: + openai_messages.append({"role": "system", "content": anthropic_request.system}) + + for msg in anthropic_request.messages: + openai_msg = {"role": msg.role} + + if isinstance(msg.content, str): + openai_msg["content"] = msg.content + else: + # Handle complex content blocks + content_parts = [] + tool_calls = [] + + for block in msg.content: + if block.type == "text" and block.text: + content_parts.append({"type": "text", "text": block.text}) + elif block.type == "image" and block.source: + content_parts.append({ + "type": "image_url", + "image_url": {"url": block.source.get("data", "")} + }) + elif block.type == "tool_use": + # Convert tool use to function call format + tool_call = { + "id": block.id or f"call_{int(time.time())}", + "type": "function", + "function": { + "name": block.name, + "arguments": json.dumps(block.input or {}) + } + } + tool_calls.append(tool_call) + elif block.type == "tool_result": + # For tool results, we need to create a tool message + # This will be handled separately as a tool response message + if msg.role == "user": + # Tool result from user should be converted to tool message + openai_messages.append({ + "role": "tool", + "tool_call_id": block.id, + "content": str(block.content) if block.content else "" + }) + else: + # Assistant tool result becomes regular text + content_parts.append({ + "type": "text", + "text": f"Tool result: {str(block.content) if block.content else ''}" + }) + + # Add tool calls to the message if any + if tool_calls: + openai_msg["tool_calls"] = tool_calls + + # Add content parts if any + if content_parts: + if len(content_parts) == 1 and content_parts[0]["type"] == "text": + openai_msg["content"] = content_parts[0]["text"] + else: + openai_msg["content"] = content_parts + elif not tool_calls: + # If no content and no tool calls, add empty content + openai_msg["content"] = "" + + openai_messages.append(openai_msg) + + req = ChatCompletionRequest( + model=anthropic_request.model, + messages=openai_messages, + max_tokens=anthropic_request.max_tokens, + max_completion_tokens=anthropic_request.max_tokens, + stop=anthropic_request.stop_sequences, + temperature=anthropic_request.temperature, + top_p=anthropic_request.top_p, + top_k=anthropic_request.top_k, + ) + + if anthropic_request.stream: + req.stream = anthropic_request.stream + req.stream_options = StreamOptions.validate({"include_usage": True}) + + if anthropic_request.tool_choice is None: + req.tool_choice = None + elif anthropic_request.tool_choice.type == "auto": + req.tool_choice = "auto" + elif anthropic_request.tool_choice.type == "any": + req.tool_choice = "required" + elif anthropic_request.tool_choice.type == "tool": + req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate({ + "type": "function", + "function": { + "name": 
anthropic_request.tool_choice.get("name") + } + }) + + tools = [] + if anthropic_request.tools is None: + return req + for tool in anthropic_request.tools: + tools.append( + ChatCompletionToolsParam.model_validate({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema + } + }) + ) + if req.tool_choice is None: + req.tool_choice = "auto" + req.tools = tools + return req + + async def create_messages( + self, + request: AnthropicMessagesRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], AnthropicMessagesResponse, + ErrorResponse]: + """ + Messages API similar to Anthropic's API. + + See https://docs.anthropic.com/en/api/messages + for the API specification. This API mimics the Anthropic messages API. + """ + chat_req = self._convert_anthropic_to_openai_request(request) + generator = await self.create_chat_completion(chat_req, raw_request) + + if isinstance(generator, ErrorResponse): + return generator + + elif isinstance(generator, ChatCompletionResponse): + return self.messages_full_converter(generator) + + return self.message_stream_converter(generator) + + def messages_full_converter( + self, + generator: ChatCompletionResponse, + ) -> AnthropicMessagesResponse: + result = AnthropicMessagesResponse( + id=generator.id, + content=[], + model=generator.model, + usage=AnthropicUsage( + input_tokens=generator.usage.prompt_tokens, + output_tokens=generator.usage.completion_tokens, + ), + ) + if generator.choices[0].finish_reason == "stop": + result.stop_reason = "end_turn" + elif generator.choices[0].finish_reason == "length": + result.stop_reason = "max_tokens" + elif generator.choices[0].finish_reason == "tool_calls": + result.stop_reason = "tool_use" + + content: List[AnthropicContentBlock] = [ + AnthropicContentBlock( + type="text", + text=generator.choices[0].message.content + ) + ] + + for tool_call in generator.choices[0].message.tool_calls: + anthropic_tool_call = AnthropicContentBlock( + type="tool_use", + id=tool_call.id, + name=tool_call.function.name, + input=json.loads(tool_call.function.arguments) + ) + content += [anthropic_tool_call] + + result.content = content + + return result + + async def message_stream_converter( + self, + generator: AsyncGenerator[str, None], + ) -> AsyncGenerator[str, None]: + try: + first_item = True + finish_reason = None + content_block_index = 0 + content_block_started = False + + async for item in generator: + if item.startswith("data:"): + data_str = item[5:].strip().rstrip("\n") + if data_str == "[DONE]": + stop_message = AnthropicStreamEvent( + type="message_stop", + ) + data = stop_message.model_dump_json(exclude_unset=True, exclude_none=True) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + else: + origin_chunk = ChatCompletionStreamResponse.model_validate_json(data_str) + + if first_item: + chunk = AnthropicStreamEvent( + type="message_start", + message=AnthropicMessagesResponse( + id=origin_chunk.id, + content=[], + model=origin_chunk.model, + ) + ) + first_item = False + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + continue + + # last chunk including usage info + if len(origin_chunk.choices) == 0: + if content_block_started: + stop_chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_stop", + ) + data = stop_chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + chunk = AnthropicStreamEvent( + type="message_delta", + delta=AnthropicDelta( + 
stop_reason=self.stop_reason_map.get(finish_reason, "end_turn"), + usage=AnthropicUsage( + input_tokens=origin_chunk.usage.prompt_tokens or 0, + output_tokens=origin_chunk.usage.completion_tokens or 0 + ) + ) + ) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + continue + + # content + if origin_chunk.choices[0].delta.content is not None: + if not content_block_started: + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_start", + content_block=AnthropicContentBlock( + type="text", + text="" + ) + ) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + content_block_started = True + + if origin_chunk.choices[0].delta.content == "": + continue + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_delta", + delta=AnthropicDelta( + type="text_delta", + text=origin_chunk.choices[0].delta.content + ) + ) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + continue + + # tool calls + elif len(origin_chunk.choices[0].delta.tool_calls) > 0: + tool_call = origin_chunk.choices[0].delta.tool_calls[0] + if tool_call.id is not None: + if content_block_started: + stop_chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_stop", + ) + data = stop_chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + content_block_started = False + content_block_index += 1 + + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_start", + content_block=AnthropicContentBlock( + type="tool_use", + id=tool_call.id, + name=tool_call.function.name if tool_call.function else None, + input={}, + ) + ) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + content_block_started = True + + else: + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_delta", + delta=AnthropicDelta( + type="input_json_delta", + partial_json=tool_call.function.arguments + ) + ) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + continue + + if origin_chunk.choices[0].finish_reason is not None: + finish_reason = origin_chunk.choices[0].finish_reason + continue + + else: + error_response = AnthropicStreamEvent( + type="error", + error=AnthropicError( + type="internal_error", + message="Invalid data format received" + ) + ) + data = error_response.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + logger.exception("Error in message stream converter.") + error_response = AnthropicStreamEvent( + type="error", + error=AnthropicError( + type="internal_error", + message=str(e) + ) + ) + data = error_response.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n"
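
For reviewers who want to poke at the new entrypoint outside pytest, below is a minimal sketch that posts directly to the /v1/messages route added in api_server.py. The launch command mirrors the test fixture above; the host, port, and exact flags are assumptions rather than anything this patch pins down.

```python
# Minimal manual check of the Anthropic-compatible /v1/messages route.
# Assumes the server was started roughly like the test fixture does, e.g.:
#   python -m vllm.entrypoints.anthropic.api_server --model Qwen/Qwen3-0.6B \
#     --served-model-name claude-3-7-sonnet-latest \
#     --enable-auto-tool-choice --tool-call-parser hermes --port 8000
import json

import requests

BASE = "http://localhost:8000"  # assumed host/port

# Non-streaming request; the response body follows AnthropicMessagesResponse.
payload = {
    "model": "claude-3-7-sonnet-latest",
    "max_tokens": 128,
    "messages": [{"role": "user", "content": "how are you!"}],
}
resp = requests.post(f"{BASE}/v1/messages", json=payload, timeout=600)
resp.raise_for_status()
body = resp.json()
print(body["stop_reason"], body["content"])

# Streaming request; the server emits SSE events (message_start,
# content_block_delta, ..., message_stop) followed by "data: [DONE]".
payload["stream"] = True
with requests.post(f"{BASE}/v1/messages", json=payload,
                   stream=True, timeout=600) as stream:
    for raw in stream.iter_lines():
        if not raw:
            continue
        data = raw.decode().removeprefix("data: ")
        if data == "[DONE]":
            break
        print(json.loads(data)["type"])
```

This is the same flow the new tests exercise through the anthropic SDK; the raw-HTTP form just makes the wire format of the stream events easy to eyeball.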