# Adapted from
# https://github.com/sgl-project/sglang/blob/220962e46b087b5829137a67eab0205b4d51720b/python/sglang/srt/entrypoints/anthropic/protocol.py
"""Pydantic models for Anthropic API protocol"""
-import json
-import time
-from typing import Any, Dict, List, Literal, Optional, Union, Annotated
-from pydantic import BaseModel, Field, field_validator, model_validator

-from anthropic.types.message_param import MessageParam as AnthropicMessageParam
-from vllm.sampling_params import BeamSearchParams, SamplingParams, GuidedDecodingParams, RequestOutputKind
-from vllm.utils import random_uuid
-import torch
+import time
+from typing import Any, Dict, List, Literal, Optional, Union

-_LONG_INFO = torch.iinfo(torch.long)
+from pydantic import BaseModel, Field, field_validator, model_validator


class AnthropicError(BaseModel):
@@ -75,13 +69,13 @@ def validate_input_schema(cls, v):
class AnthropicToolChoice(BaseModel):
    """Tool Choice definition"""
    type: Literal["auto", "any", "tool"]
-    name: Optional[str]
+    name: Optional[str] = None
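Worth noting why the `= None` matters: in pydantic v2 an `Optional[...]` annotation alone does not make a field optional to supply, so the old declaration made `name` required even for `auto`/`any` choices. A minimal standalone sketch of the before/after behavior (assuming pydantic v2; the model names here are illustrative, not from the diff):

```python
from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class ToolChoiceOld(BaseModel):
    type: Literal["auto", "any", "tool"]
    name: Optional[str]  # no default: caller must pass name, even if None


class ToolChoiceNew(BaseModel):
    type: Literal["auto", "any", "tool"]
    name: Optional[str] = None  # defaulted: name may be omitted


try:
    ToolChoiceOld(type="auto")
except ValidationError as e:
    print(e.errors()[0]["type"])  # -> "missing"

print(ToolChoiceNew(type="auto"))  # -> type='auto' name=None
```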

class AnthropicMessagesRequest(BaseModel):
    """Anthropic Messages API request"""
    model: str
-    messages: List[AnthropicMessageParam]
+    messages: List[AnthropicMessage]
    max_tokens: int
    metadata: Optional[Dict[str, Any]] = None
    stop_sequences: Optional[List[str]] = None
@@ -90,131 +84,8 @@ class AnthropicMessagesRequest(BaseModel):
    temperature: Optional[float] = None
    tool_choice: Optional[AnthropicToolChoice] = None
    tools: Optional[List[AnthropicTool]] = None
-    top_p: Optional[float] = None
-
-    # --8<-- [start:chat-completion-sampling-params]
-    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
-    stop: Optional[Union[str, list[str]]] = []
-    best_of: Optional[int] = None
-    use_beam_search: bool = False
    top_k: Optional[int] = None
-    min_p: Optional[float] = None
-    frequency_penalty: Optional[float] = 0.0
-    presence_penalty: Optional[float] = 0.0
-    repetition_penalty: Optional[float] = None
-    length_penalty: float = 1.0
-    stop_token_ids: Optional[list[int]] = []
-    include_stop_str_in_output: bool = False
-    ignore_eos: bool = False
-    min_tokens: int = 0
-    skip_special_tokens: bool = True
-    spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
-    prompt_logprobs: Optional[int] = None
-    allowed_token_ids: Optional[list[int]] = None
-    bad_words: list[str] = Field(default_factory=list)
-
-    # --8<-- [end:chat-completion-sampling-params]
-
-    chat_template: Optional[str] = Field(
-        default=None,
-        description=(
-            "A Jinja template to use for this conversion. "
-            "As of transformers v4.44, default chat template is no longer "
-            "allowed, so you must provide a chat template if the tokenizer "
-            "does not define one."),
-    )
-    chat_template_kwargs: Optional[dict[str, Any]] = Field(
-        default=None,
-        description=(
-            "Additional keyword args to pass to the template renderer. "
-            "Will be accessible by the chat template."),
-    )
-    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
-        default=None,
-        description=("Additional kwargs to pass to the HF processor."),
-    )
-    priority: int = Field(
-        default=0,
-        description=(
-            "The priority of the request (lower means earlier handling; "
-            "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."),
-    )
-    request_id: str = Field(
-        default_factory=lambda: f"{random_uuid()}",
-        description=(
-            "The request_id related to this request. If the caller does "
-            "not set it, a random_uuid will be generated. This id is used "
-            "through out the inference process and return in response."),
-    )
-
-    _DEFAULT_SAMPLING_PARAMS: dict = {
-        "repetition_penalty": 1.0,
-        "temperature": 1.0,
-        "top_p": 1.0,
-        "top_k": 0,
-        "min_p": 0.0,
-    }
-
-    def to_beam_search_params(
-            self, max_tokens: int,
-            default_sampling_params: dict) -> BeamSearchParams:
-
-        n = self.n if self.n is not None else 1
-        if (temperature := self.temperature) is None:
-            temperature = default_sampling_params.get(
-                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
-
-        return BeamSearchParams(
-            beam_width=n,
-            max_tokens=max_tokens,
-            ignore_eos=self.ignore_eos,
-            temperature=temperature,
-            length_penalty=self.length_penalty,
-            include_stop_str_in_output=self.include_stop_str_in_output,
-        )
-
-    def to_sampling_params(
-        self,
-        max_tokens: int,
-        default_sampling_params: dict,
-    ) -> SamplingParams:
-
-        # Default parameters
-        if (repetition_penalty := self.repetition_penalty) is None:
-            repetition_penalty = default_sampling_params.get(
-                "repetition_penalty",
-                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
-            )
-        if (temperature := self.temperature) is None:
-            temperature = default_sampling_params.get(
-                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
-        if (top_p := self.top_p) is None:
-            top_p = default_sampling_params.get(
-                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
-        if (top_k := self.top_k) is None:
-            top_k = default_sampling_params.get(
-                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
-        if (min_p := self.min_p) is None:
-            min_p = default_sampling_params.get(
-                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
-
-        return SamplingParams.from_optional(
-            n=1,
-            best_of=self.best_of,
-            presence_penalty=self.presence_penalty,
-            frequency_penalty=self.frequency_penalty,
-            repetition_penalty=repetition_penalty,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            min_p=min_p,
-            seed=self.seed,
-            stop=self.stop,
-            stop_token_ids=self.stop_token_ids,
-            max_tokens=max_tokens,
-        )
+    top_p: Optional[float] = None

    @field_validator("model")
    @classmethod
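With the vLLM-specific sampling surface gone, the request model keeps only the sampling knobs the Anthropic Messages API itself defines. A hedged sketch of a request that still validates after this diff (the model id and message dict are illustrative; this assumes `AnthropicMessage` accepts a role/content mapping, which the diff does not show):

```python
# Only Anthropic-native fields remain declared on the model. The removed
# vLLM knobs (seed, min_p, best_of, repetition_penalty, ...) are no longer
# fields; how strictly extras are rejected depends on the model's `extra`
# config, which this diff does not show.
req = AnthropicMessagesRequest(
    model="claude-3-5-sonnet-latest",  # illustrative model id
    max_tokens=256,
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.7,
    top_k=40,
    top_p=0.9,
    stop_sequences=["\n\nHuman:"],
)
```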
@@ -233,10 +104,16 @@ def validate_max_tokens(cls, v):

class AnthropicDelta(BaseModel):
    """Delta for streaming responses"""
-    type: Literal["text_delta", "input_json_delta"]
+    type: Literal["text_delta", "input_json_delta"] = None
    text: Optional[str] = None
    partial_json: Optional[str] = None

+    # Message delta
+    stop_reason: Optional[
+        Literal["end_turn", "max_tokens", "stop_sequence", "tool_use", "pause_turn", "refusal"]] = None
+    stop_sequence: Optional[str] = None
+    usage: AnthropicUsage = None
+

class AnthropicStreamEvent(BaseModel):
    """Streaming event"""
@@ -261,7 +138,7 @@ class AnthropicMessagesResponse(BaseModel):
    model: str
    stop_reason: Optional[Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]] = None
    stop_sequence: Optional[str] = None
-    usage: AnthropicUsage
+    usage: AnthropicUsage = None

    def model_post_init(self, __context):
        if not self.id:
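A closing observation on `usage: AnthropicUsage = None` (here and in `AnthropicDelta`): the annotation is not `Optional`, but pydantic v2 does not validate defaults unless `validate_default=True`, so the field behaves as "absent until filled in". A standalone sketch of that pattern, with illustrative model names:

```python
from pydantic import BaseModel


class Usage(BaseModel):
    input_tokens: int
    output_tokens: int


class Response(BaseModel):
    model: str
    # Same pattern as the diff: non-Optional annotation, None default.
    # pydantic v2 leaves defaults unvalidated, so this constructs fine.
    usage: Usage = None


r = Response(model="demo")
print(r.usage)  # -> None, until the server fills it in later
r.usage = Usage(input_tokens=3, output_tokens=5)  # plain assignment, unvalidated
```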