
Commit 5a3fa47

support anthropic endpoint
Signed-off-by: liuli <[email protected]>
1 parent: 68be681

File tree: 6 files changed (+461, −537 lines)


vllm/entrypoints/anthropic/api_server.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -55,7 +55,7 @@
 # yapf: enable
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
-                                    log_non_default_args, with_cancellation)
+                                    with_cancellation)
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
 from vllm.transformers_utils.config import (
@@ -133,7 +133,7 @@ async def create_messages(request: AnthropicMessagesRequest,
                 status_code=generator.code)

     elif isinstance(generator, AnthropicMessagesResponse):
-        return JSONResponse(content=generator.model_dump())
+        return JSONResponse(content=generator.model_dump(exclude_none=True, exclude_unset=True))

     return StreamingResponse(content=generator, media_type="text/event-stream")
```
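Note on the `model_dump` change above: Pydantic's `exclude_none` drops fields whose value is `None`, and `exclude_unset` drops fields the caller never assigned, so the serialized response omits optional keys rather than emitting explicit nulls. A minimal sketch with a hypothetical `Reply` model (not from this commit):

```python
from typing import Optional

from pydantic import BaseModel


class Reply(BaseModel):
    id: str
    stop_reason: Optional[str] = None
    stop_sequence: Optional[str] = None


r = Reply(id="msg_1", stop_reason="end_turn")

# Default dump keeps the explicit null:
print(r.model_dump())
# {'id': 'msg_1', 'stop_reason': 'end_turn', 'stop_sequence': None}

# With the flags from this commit, unset/None fields disappear:
print(r.model_dump(exclude_none=True, exclude_unset=True))
# {'id': 'msg_1', 'stop_reason': 'end_turn'}
```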

```diff
@@ -232,7 +232,6 @@ def setup_server(args):
     ready to serve."""

     logger.info("vLLM API server version %s", VLLM_VERSION)
-    log_non_default_args(args)

     if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
         ToolParserManager.import_tool_parser(args.tool_parser_plugin)
```

vllm/entrypoints/anthropic/protocol.py

Lines changed: 14 additions & 137 deletions
```diff
@@ -3,17 +3,11 @@
 # Adapted from
 # https://github.com/sgl-project/sglang/blob/220962e46b087b5829137a67eab0205b4d51720b/python/sglang/srt/entrypoints/anthropic/protocol.py
 """Pydantic models for Anthropic API protocol"""
-import json
-import time
-from typing import Any, Dict, List, Literal, Optional, Union, Annotated
-from pydantic import BaseModel, Field, field_validator, model_validator

-from anthropic.types.message_param import MessageParam as AnthropicMessageParam
-from vllm.sampling_params import BeamSearchParams, SamplingParams, GuidedDecodingParams, RequestOutputKind
-from vllm.utils import random_uuid
-import torch
+import time
+from typing import Any, Dict, List, Literal, Optional, Union

-_LONG_INFO = torch.iinfo(torch.long)
+from pydantic import BaseModel, Field, field_validator, model_validator


 class AnthropicError(BaseModel):
@@ -75,13 +69,13 @@ def validate_input_schema(cls, v):
 class AnthropicToolChoice(BaseModel):
     """Tool Choice definition"""
     type: Literal["auto", "any", "tool"]
-    name: Optional[str]
+    name: Optional[str] = None


 class AnthropicMessagesRequest(BaseModel):
     """Anthropic Messages API request"""
     model: str
-    messages: List[AnthropicMessageParam]
+    messages: List[AnthropicMessage]
     max_tokens: int
     metadata: Optional[Dict[str, Any]] = None
     stop_sequences: Optional[List[str]] = None
```
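The `name` default matters under Pydantic v2: annotating a field `Optional[str]` only widens the type; without `= None` the field is still required, so tool-choice payloads that omit `name` (valid for `type: "auto"` or `"any"`) would fail validation. A standalone sketch, not the models from this file:

```python
from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class Before(BaseModel):
    type: Literal["auto", "any", "tool"]
    name: Optional[str]            # required; may be None but must be present


class After(BaseModel):
    type: Literal["auto", "any", "tool"]
    name: Optional[str] = None     # genuinely optional


try:
    Before(type="auto")
except ValidationError:
    print("Before: missing 'name' -> validation error")

print(After(type="auto"))          # type='auto' name=None
```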
```diff
@@ -90,131 +84,8 @@ class AnthropicMessagesRequest(BaseModel):
     temperature: Optional[float] = None
     tool_choice: Optional[AnthropicToolChoice] = None
     tools: Optional[List[AnthropicTool]] = None
-    top_p: Optional[float] = None
-
-    # --8<-- [start:chat-completion-sampling-params]
-    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
-    stop: Optional[Union[str, list[str]]] = []
-    best_of: Optional[int] = None
-    use_beam_search: bool = False
     top_k: Optional[int] = None
-    min_p: Optional[float] = None
-    frequency_penalty: Optional[float] = 0.0
-    presence_penalty: Optional[float] = 0.0
-    repetition_penalty: Optional[float] = None
-    length_penalty: float = 1.0
-    stop_token_ids: Optional[list[int]] = []
-    include_stop_str_in_output: bool = False
-    ignore_eos: bool = False
-    min_tokens: int = 0
-    skip_special_tokens: bool = True
-    spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
-    prompt_logprobs: Optional[int] = None
-    allowed_token_ids: Optional[list[int]] = None
-    bad_words: list[str] = Field(default_factory=list)
-
-    # --8<-- [end:chat-completion-sampling-params]
-
-    chat_template: Optional[str] = Field(
-        default=None,
-        description=(
-            "A Jinja template to use for this conversion. "
-            "As of transformers v4.44, default chat template is no longer "
-            "allowed, so you must provide a chat template if the tokenizer "
-            "does not define one."),
-    )
-    chat_template_kwargs: Optional[dict[str, Any]] = Field(
-        default=None,
-        description=(
-            "Additional keyword args to pass to the template renderer. "
-            "Will be accessible by the chat template."),
-    )
-    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
-        default=None,
-        description=("Additional kwargs to pass to the HF processor."),
-    )
-    priority: int = Field(
-        default=0,
-        description=(
-            "The priority of the request (lower means earlier handling; "
-            "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."),
-    )
-    request_id: str = Field(
-        default_factory=lambda: f"{random_uuid()}",
-        description=(
-            "The request_id related to this request. If the caller does "
-            "not set it, a random_uuid will be generated. This id is used "
-            "through out the inference process and return in response."),
-    )
-
-    _DEFAULT_SAMPLING_PARAMS: dict = {
-        "repetition_penalty": 1.0,
-        "temperature": 1.0,
-        "top_p": 1.0,
-        "top_k": 0,
-        "min_p": 0.0,
-    }
-
-    def to_beam_search_params(
-            self, max_tokens: int,
-            default_sampling_params: dict) -> BeamSearchParams:
-
-        n = self.n if self.n is not None else 1
-        if (temperature := self.temperature) is None:
-            temperature = default_sampling_params.get(
-                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
-
-        return BeamSearchParams(
-            beam_width=n,
-            max_tokens=max_tokens,
-            ignore_eos=self.ignore_eos,
-            temperature=temperature,
-            length_penalty=self.length_penalty,
-            include_stop_str_in_output=self.include_stop_str_in_output,
-        )
-
-    def to_sampling_params(
-        self,
-        max_tokens: int,
-        default_sampling_params: dict,
-    ) -> SamplingParams:
-
-        # Default parameters
-        if (repetition_penalty := self.repetition_penalty) is None:
-            repetition_penalty = default_sampling_params.get(
-                "repetition_penalty",
-                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
-            )
-        if (temperature := self.temperature) is None:
-            temperature = default_sampling_params.get(
-                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
-        if (top_p := self.top_p) is None:
-            top_p = default_sampling_params.get(
-                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
-        if (top_k := self.top_k) is None:
-            top_k = default_sampling_params.get(
-                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
-        if (min_p := self.min_p) is None:
-            min_p = default_sampling_params.get(
-                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
-
-        return SamplingParams.from_optional(
-            n=1,
-            best_of=self.best_of,
-            presence_penalty=self.presence_penalty,
-            frequency_penalty=self.frequency_penalty,
-            repetition_penalty=repetition_penalty,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            min_p=min_p,
-            seed=self.seed,
-            stop=self.stop,
-            stop_token_ids=self.stop_token_ids,
-            max_tokens=max_tokens,
-        )
+    top_p: Optional[float] = None

     @field_validator("model")
     @classmethod
```
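After this trim, `AnthropicMessagesRequest` carries only the fields the Anthropic Messages API itself defines, plus `top_k`/`top_p`; the OpenAI-style sampling knobs and the `to_sampling_params`/`to_beam_search_params` converters are gone. A hedged sketch of the accepted shape, assuming `AnthropicMessage` takes the usual role/content form (it is defined earlier in this file but not shown in the diff) and that the `model` validator below accepts the name:

```python
from vllm.entrypoints.anthropic.protocol import AnthropicMessagesRequest

payload = {
    "model": "claude-3-opus",     # model name is illustrative
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.7,
    "stop_sequences": ["\n\nHuman:"],
    "top_k": 40,
    "top_p": 0.9,
}

request = AnthropicMessagesRequest.model_validate(payload)
print(request.tools)              # None -- tools/tool_choice stay optional
```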
```diff
@@ -233,10 +104,16 @@ def validate_max_tokens(cls, v):

 class AnthropicDelta(BaseModel):
     """Delta for streaming responses"""
-    type: Literal["text_delta", "input_json_delta"]
+    type: Literal["text_delta", "input_json_delta"] = None
     text: Optional[str] = None
     partial_json: Optional[str] = None

+    # Message delta
+    stop_reason: Optional[
+        Literal["end_turn", "max_tokens", "stop_sequence", "tool_use", "pause_turn", "refusal"]] = None
+    stop_sequence: Optional[str] = None
+    usage: AnthropicUsage = None
+

 class AnthropicStreamEvent(BaseModel):
     """Streaming event"""
```
261138
model: str
262139
stop_reason: Optional[Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]] = None
263140
stop_sequence: Optional[str] = None
264-
usage: AnthropicUsage
141+
usage: AnthropicUsage = None
265142

266143
def model_post_init(self, __context):
267144
if not self.id:
