
Commit 20f42bd

support anthropic endpoint
Signed-off-by: liuli <[email protected]>
1 parent 845420a commit 20f42bd

File tree

4 files changed: +896 −0 lines


vllm/entrypoints/anthropic/__init__.py

Whitespace-only changes.
Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from:
# https://github.com/vllm/vllm/entrypoints/openai/api_server.py

import asyncio
import signal
import tempfile
from argparse import Namespace
from http import HTTPStatus
from typing import Optional

import uvloop
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import AnthropicErrorResponse, AnthropicMessagesRequest, \
    AnthropicMessagesResponse
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.chat_utils import (load_chat_template,
                                         resolve_hf_chat_template,
                                         resolve_mistral_chat_template)
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.api_server import validate_api_server_args, create_server_socket, load_log_config, \
    lifespan, build_async_engine_client, validate_json_request
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_models import OpenAIServingModels, BaseModelPath, LoRAModulePath
#
# yapf: enable
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
                                    with_cancellation)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import (FlexibleArgumentParser,
                        is_valid_ipv6_address,
                        set_ulimit)
from vllm.version import __version__ as VLLM_VERSION

prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.anthropic.api_server')

_running_tasks: set[asyncio.Task] = set()

router = APIRouter()


def messages(request: Request) -> Optional[AnthropicServingMessages]:
    return request.app.state.anthropic_serving_messages


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
    await engine_client(raw_request).check_health()
    return Response(status_code=200)


@router.get("/ping", response_class=Response)
@router.post("/ping", response_class=Response)
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
    return await health(raw_request)


@router.post("/v1/messages",
             dependencies=[Depends(validate_json_request)],
             responses={
                 HTTPStatus.OK.value: {
                     "content": {
                         "text/event-stream": {}
                     }
                 },
                 HTTPStatus.BAD_REQUEST.value: {
                     "model": AnthropicErrorResponse
                 },
                 HTTPStatus.NOT_FOUND.value: {
                     "model": AnthropicErrorResponse
                 },
                 HTTPStatus.INTERNAL_SERVER_ERROR.value: {
                     "model": AnthropicErrorResponse
                 }
             })
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest,
                          raw_request: Request):
    handler = messages(raw_request)
    if handler is None:
        return messages(raw_request).create_error_response(
            message="The model does not support Chat Completions API")

    generator = await handler.create_messages(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    elif isinstance(generator, AnthropicMessagesResponse):
        return JSONResponse(content=generator.model_dump(exclude_none=True, exclude_unset=True))

    return StreamingResponse(content=generator, media_type="text/event-stream")


async def init_app_state(
    engine_client: EngineClient,
    vllm_config: VllmConfig,
    state: State,
    args: Namespace,
) -> None:
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    model_config = vllm_config.model_config

    default_mm_loras = (vllm_config.lora_config.default_mm_loras
                        if vllm_config.lora_config is not None else {})
    lora_modules = args.lora_modules
    if default_mm_loras:
        default_mm_lora_paths = [
            LoRAModulePath(
                name=modality,
                path=lora_path,
            ) for modality, lora_path in default_mm_loras.items()
        ]
        if args.lora_modules is None:
            lora_modules = default_mm_lora_paths
        else:
            lora_modules += default_mm_lora_paths

    resolved_chat_template = load_chat_template(args.chat_template)
    if resolved_chat_template is not None:
        # Get the tokenizer to check official template
        tokenizer = await engine_client.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            # The warning is logged in resolve_mistral_chat_template.
            resolved_chat_template = resolve_mistral_chat_template(
                chat_template=resolved_chat_template)
        else:
            hf_chat_template = resolve_hf_chat_template(
                tokenizer=tokenizer,
                chat_template=None,
                tools=None,
                model_config=vllm_config.model_config,
            )

            if hf_chat_template != resolved_chat_template:
                logger.warning(
                    "Using supplied chat template: %s\n"
                    "It is different from official chat template '%s'. "
                    "This discrepancy may lead to performance degradation.",
                    resolved_chat_template, args.model)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()
    state.anthropic_serving_messages = AnthropicServingMessages(
        engine_client,
        model_config,
        state.openai_serving_models,
        args.response_role,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
    )


def setup_server(args):
    """Validate API server args, set up signal handler, create socket
    ready to serve."""

    logger.info("vLLM API server version %s", VLLM_VERSION)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    validate_api_server_args(args)

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    sock_addr = (args.host or "", args.port)
    sock = create_server_socket(sock_addr)

    # workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active
    set_ulimit()

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    addr, port = sock_addr
    is_ssl = args.ssl_keyfile and args.ssl_certfile
    host_part = f"[{addr}]" if is_valid_ipv6_address(
        addr) else addr or "0.0.0.0"
    listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"

    return listen_address, sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server."""
    listen_address, sock = setup_server(args)
    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    return app


async def run_server_worker(listen_address,
                            sock,
                            args,
                            client_config=None,
                            **uvicorn_kwargs) -> None:
    """Run a single API server worker."""

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    server_index = client_config.get("client_index", 0) if client_config else 0

    # Load logging config for uvicorn if specified
    log_config = load_log_config(args.log_config_file)
    if log_config is not None:
        uvicorn_kwargs['log_config'] = log_config

    async with build_async_engine_client(
            args,
            client_config=client_config,
    ) as engine_client:
        app = build_app(args)

        vllm_config = await engine_client.get_vllm_config()
        await init_app_state(engine_client, vllm_config, app.state, args)

        logger.info("Starting vLLM API server %d on %s", server_index,
                    listen_address)
        shutdown_task = await serve_http(
            app,
            sock=sock,
            enable_ssl_refresh=args.enable_ssl_refresh,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            # NOTE: When the 'disable_uvicorn_access_log' value is True,
            # no access log will be output.
            access_log=not args.disable_uvicorn_access_log,
            timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    try:
        await shutdown_task
    finally:
        sock.close()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()
    parser = FlexibleArgumentParser(
        description="vLLM Anthropic-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
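
For context, here is a minimal client sketch against the new /v1/messages route. It is not part of the commit: it assumes the server above is running locally on http://localhost:8000 (e.g. launched via the module's __main__ block with the usual vLLM serving flags) and that AnthropicMessagesRequest accepts the standard Anthropic Messages fields (model, max_tokens, messages); the model name is illustrative, and the exact schema lives in vllm/entrypoints/anthropic/protocol.py, which is not shown in this diff.

# Minimal client sketch (assumptions: server reachable at localhost:8000,
# request body follows the Anthropic Messages API shape; model name is illustrative).
import json
import urllib.request

payload = {
    "model": "my-served-model",  # hypothetical --served-model-name value
    "max_tokens": 128,
    "messages": [{"role": "user", "content": "Hello from the Anthropic-style endpoint!"}],
}

req = urllib.request.Request(
    "http://localhost:8000/v1/messages",
    data=json.dumps(payload).encode("utf-8"),
    headers={"content-type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(req) as resp:
    # Non-streaming responses come back as a single JSON body, matching the
    # AnthropicMessagesResponse branch of create_messages above.
    print(json.loads(resp.read().decode("utf-8")))

A streaming request would instead receive text/event-stream chunks, corresponding to the StreamingResponse branch of create_messages.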
