 
 import uvicorn
 from fastapi import FastAPI, HTTPException, Request, Header
+from fastapi.exceptions import RequestValidationError
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, AliasChoices, ConfigDict
 from openai import AsyncOpenAI, APIError, RateLimitError, AuthenticationError
 
 # --- [System Configuration] ---
@@ -100,12 +101,24 @@ class Parameters(BaseModel):
     top_p: Optional[float] = 0.8
     top_k: Optional[int] = None
     seed: Optional[int] = 1234
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = Field(
+        None,
+        validation_alias=AliasChoices("max_tokens", "max_length")
+    )
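+    # Illustrative note: with the alias choices above, a request carrying {"max_length": 512}
+    # and one carrying {"max_tokens": 512} both populate this field.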
+    frequency_penalty: Optional[float] = 0.0
+    presence_penalty: Optional[float] = 0.0
+    repetition_penalty: Optional[float] = 1.0
+
+    # OpenAI native format
     stop: Optional[Union[str, List[str]]] = None
+    # DashScope-compatible format
+    stop_words: Optional[List[Dict[str, Any]]] = None
     enable_thinking: bool = False
     thinking_budget: Optional[int] = None
     tools: Optional[List[Dict[str, Any]]] = None
     tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    # Explicitly allow population by field name as well as by alias
+    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
 
 class GenerationRequest(BaseModel):
     model: str
@@ -179,15 +192,85 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
179192 "stream" : params .incremental_output or params .enable_thinking ,
180193 }
181194
+        if params.frequency_penalty is not None:
+            openai_params["frequency_penalty"] = params.frequency_penalty
+
+        if params.presence_penalty is not None:
+            openai_params["presence_penalty"] = params.presence_penalty
+
+        extra_body = {}
+        # These are not standard OpenAI parameters and must go into extra_body
+        # === [Validation and correction logic] ===
+
+        # 1. Validate repetition_penalty (fixes deepseek_case_22)
+        # DashScope requires > 0.0, otherwise it raises InvalidParameter
+        if params.repetition_penalty is not None:
+            if params.repetition_penalty <= 0:
+                return JSONResponse(
+                    status_code=400,
+                    content={
+                        "code": "InvalidParameter",
+                        "message": "<400> InternalError.Algo.InvalidParameter: Repetition_penalty should be greater than 0.0"
+                    }
+                )
+            extra_body["repetition_penalty"] = params.repetition_penalty
+
+        # 2. Validate top_k
+        if params.top_k is not None:
+            # 2.1 Reject negative values (fixes deepseek_case_18)
+            # Pydantic has already ensured this is an int; check the value here
+            if params.top_k < 0:
+                return JSONResponse(
+                    status_code=400,
+                    content={
+                        "code": "InvalidParameter",
+                        "message": "<400> InternalError.Algo.InvalidParameter: Parameter top_k be greater than or equal to 0"
+                    }
+                )
+
+            # 2.2 Clamp the upper bound (fixes deepseek_case_28)
+            # SiliconFlow limits top_k to [1, 100].
+            # A value of 0 usually means "disable" (i.e. -1) or "let the model decide", so mapping it to -1 is the safer choice.
+            # A value > 100 (e.g. 1025) must be clamped to 100, or the upstream rejects the request.
+            if params.top_k == 0:
+                extra_body["top_k"] = -1  # Disable
+            elif params.top_k > 100:
+                extra_body["top_k"] = 100  # Clamp to max supported by upstream
+            else:
+                extra_body["top_k"] = params.top_k
+        if extra_body:
+            openai_params["extra_body"] = extra_body
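+        # Illustrative example: with repetition_penalty=1.1 and top_k=1025 the call becomes, roughly,
+        #   client.chat.completions.create(..., extra_body={"repetition_penalty": 1.1, "top_k": 100})
+        # The OpenAI Python SDK merges extra_body into the request payload for such non-standard fields.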
+
         if params.tools:
             openai_params["tools"] = params.tools
         if params.tool_choice:
             openai_params["tool_choice"] = params.tool_choice
 
-        if params.max_tokens: openai_params["max_tokens"] = params.max_tokens
-        if params.stop: openai_params["stop"] = params.stop
+        if params.max_tokens is not None:
+            openai_params["max_tokens"] = params.max_tokens
+            logger.debug(f"[Request] Truncation enabled: max_tokens={params.max_tokens}")
+        else:
+            logger.debug("[Request] No max_tokens found in parameters, model will generate full response")
+
         if params.seed: openai_params["seed"] = params.seed
 
+        # === Handle stop-word compatibility ===
+        # Prefer the native OpenAI stop parameter
+        if params.stop:
+            openai_params["stop"] = params.stop
+        # If stop is absent but stop_words (DashScope format) is present, convert it
+        elif params.stop_words:
+            # Extract every stop_str whose mode is "exclude" (the default)
+            # Note: DashScope's stop_words is a list[dict] structure
+            stop_list = []
+            for sw in params.stop_words:
+                # Only handle entries in exclude mode or with no mode specified
+                if sw.get("mode", "exclude") == "exclude" and "stop_str" in sw:
+                    stop_list.append(sw["stop_str"])
+
+            if stop_list:
+                openai_params["stop"] = stop_list
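+            # Example of the conversion (hypothetical values):
+            #   stop_words=[{"stop_str": "Observation:", "mode": "exclude"}, {"stop_str": "###"}]
+            #   -> openai_params["stop"] = ["Observation:", "###"]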
+
         try:
             if openai_params["stream"]:
                 raw_resp = await self.client.chat.completions.with_raw_response.create(**openai_params)
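+                # Note: with_raw_response wraps the call so response headers stay accessible;
+                # .parse() later recovers the typed completion object (openai>=1.x SDK behavior).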
@@ -204,7 +287,7 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
             return self._format_unary_response(raw_resp.parse(), trace_id)
 
         except APIError as e:
-            logger.error(f"Upstream API Error: {str(e)}")
+            logger.error(f"[request id: {initial_request_id}] Upstream API Error: {str(e)}")
             error_code = "InternalError"
             if isinstance(e, RateLimitError): error_code = "Throttling.RateQuota"
             elif isinstance(e, AuthenticationError): error_code = "InvalidApiKey"
@@ -359,6 +442,23 @@ def create_app() -> FastAPI:
         CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
     )
 
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(request, exc):
+        # Pull the key details from the first error and build a DashScope-style error message
+        error_msg = exc.errors()[0].get("msg", "Invalid parameter")
+        loc = exc.errors()[0].get("loc", [])
+        param_name = loc[-1] if loc else "unknown"
+
+        logger.error(f"Validation Error: {exc.errors()}")
+
+        return JSONResponse(
+            status_code=400,
+            content={
+                "code": "InvalidParameter",
+                "message": f"<400> InternalError.Algo.InvalidParameter: Parameter {param_name} check failed: {error_msg}"
+            }
+        )
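+        # For example (assumed request), a payload with parameters.top_k="abc" fails Pydantic
+        # validation and surfaces here as HTTP 400 with code "InvalidParameter" and a message like
+        # "... Parameter top_k check failed: Input should be a valid integer, ..."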
+
     @app.middleware("http")
     async def request_tracker(request: Request, call_next):
         SERVER_STATE.increment_request()