
Commit a96790e

Update mock_server.py
1 parent 49a2a88 commit a96790e

File tree

1 file changed: +105 -5 lines changed


tests/mock_server.py

Lines changed: 105 additions & 5 deletions
@@ -11,9 +11,10 @@
 
 import uvicorn
 from fastapi import FastAPI, HTTPException, Request, Header
+from fastapi.exceptions import RequestValidationError
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, AliasChoices, ConfigDict
 from openai import AsyncOpenAI, APIError, RateLimitError, AuthenticationError
 
 # --- [System Configuration] ---
@@ -100,12 +101,24 @@ class Parameters(BaseModel):
     top_p: Optional[float] = 0.8
     top_k: Optional[int] = None
     seed: Optional[int] = 1234
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = Field(
+        None,
+        validation_alias=AliasChoices("max_tokens", "max_length")
+    )
+    frequency_penalty: Optional[float] = 0.0
+    presence_penalty: Optional[float] = 0.0
+    repetition_penalty: Optional[float] = 1.0
+
+    # OpenAI-native format
     stop: Optional[Union[str, List[str]]] = None
+    # DashScope-compatible format
+    stop_words: Optional[List[Dict[str, Any]]] = None
     enable_thinking: bool = False
     thinking_budget: Optional[int] = None
     tools: Optional[List[Dict[str, Any]]] = None
     tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    # Explicitly allow populating fields by attribute name as well as by alias
+    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
 
 class GenerationRequest(BaseModel):
     model: str
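
With validation_alias=AliasChoices("max_tokens", "max_length") and populate_by_name=True, the field should accept the DashScope-style max_length key as well as the OpenAI-style max_tokens. A minimal, self-contained sketch of that behaviour (the stripped-down Params class below is only a stand-in for the real Parameters model in tests/mock_server.py):

from typing import Optional
from pydantic import AliasChoices, BaseModel, ConfigDict, Field

class Params(BaseModel):  # illustrative stand-in, not the real Parameters model
    model_config = ConfigDict(populate_by_name=True)
    max_tokens: Optional[int] = Field(
        None,
        validation_alias=AliasChoices("max_tokens", "max_length"),
    )

# A DashScope-style payload using "max_length" resolves to the same field ...
print(Params.model_validate({"max_length": 512}).max_tokens)  # 512
# ... as the OpenAI-style name, and keyword construction still works too.
print(Params.model_validate({"max_tokens": 256}).max_tokens)  # 256
print(Params(max_tokens=128).max_tokens)                      # 128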
@@ -179,15 +192,85 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
             "stream": params.incremental_output or params.enable_thinking,
         }
 
+        if params.frequency_penalty is not None:
+            openai_params["frequency_penalty"] = params.frequency_penalty
+
+        if params.presence_penalty is not None:
+            openai_params["presence_penalty"] = params.presence_penalty
+
+        extra_body = {}
+        # These are not OpenAI-standard parameters and must go into extra_body
+        # === [Validation and correction logic] ===
+
+        # 1. Validate repetition_penalty (fixes deepseek_case_22)
+        # DashScope requires > 0.0, otherwise it returns InvalidParameter
+        if params.repetition_penalty is not None:
+            if params.repetition_penalty <= 0:
+                return JSONResponse(
+                    status_code=400,
+                    content={
+                        "code": "InvalidParameter",
+                        "message": "<400> InternalError.Algo.InvalidParameter: Repetition_penalty should be greater than 0.0"
+                    }
+                )
+            extra_body["repetition_penalty"] = params.repetition_penalty
+
+        # 2. Validate top_k
+        if params.top_k is not None:
+            # 2.1 Reject negative values (fixes deepseek_case_18)
+            # Pydantic has already guaranteed this is an int; only the value is checked here
+            if params.top_k < 0:
+                return JSONResponse(
+                    status_code=400,
+                    content={
+                        "code": "InvalidParameter",
+                        "message": "<400> InternalError.Algo.InvalidParameter: Parameter top_k be greater than or equal to 0"
+                    }
+                )
+
+            # 2.2 Clamp the upper bound (fixes deepseek_case_28)
+            # SiliconFlow limits top_k to [1, 100].
+            # A value of 0 usually means "disabled" (i.e. -1) or "let the model decide", so mapping it to -1 is the safer choice.
+            # Values above 100 (e.g. 1025) must be clamped to 100, otherwise the upstream service rejects the request.
+            if params.top_k == 0:
+                extra_body["top_k"] = -1  # Disable
+            elif params.top_k > 100:
+                extra_body["top_k"] = 100  # Clamp to max supported by upstream
+            else:
+                extra_body["top_k"] = params.top_k
+        if extra_body:
+            openai_params["extra_body"] = extra_body
+
         if params.tools:
             openai_params["tools"] = params.tools
         if params.tool_choice:
             openai_params["tool_choice"] = params.tool_choice
 
-        if params.max_tokens: openai_params["max_tokens"] = params.max_tokens
-        if params.stop: openai_params["stop"] = params.stop
+        if params.max_tokens is not None:
+            openai_params["max_tokens"] = params.max_tokens
+            logger.debug(f"[Request] Truncation enabled: max_tokens={params.max_tokens}")
+        else:
+            logger.debug("[Request] No max_tokens found in parameters, model will generate full response")
+
         if params.seed: openai_params["seed"] = params.seed
 
+        # === Handle stop-word compatibility ===
+        # Prefer the OpenAI-native stop parameter
+        if params.stop:
+            openai_params["stop"] = params.stop
+        # If stop is absent but stop_words (DashScope format) is present, convert it
+        elif params.stop_words:
+            # Collect every stop_str whose mode is "exclude" (the default)
+            # Note: DashScope's stop_words is a list[dict] structure
+            stop_list = []
+            for sw in params.stop_words:
+                # Only handle entries in exclude mode, or with no mode specified
+                if sw.get("mode", "exclude") == "exclude" and "stop_str" in sw:
+                    stop_list.append(sw["stop_str"])
+
+            if stop_list:
+                openai_params["stop"] = stop_list
+
         try:
             if openai_params["stream"]:
                 raw_resp = await self.client.chat.completions.with_raw_response.create(**openai_params)
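
The stop_words branch above flattens DashScope's list-of-dicts format into the plain stop list the OpenAI client expects. A standalone sketch of that conversion (the helper name convert_stop_words is only for illustration; the mock server does this inline inside generate()):

from typing import Any, Dict, List, Optional, Union

def convert_stop_words(
    stop: Optional[Union[str, List[str]]],
    stop_words: Optional[List[Dict[str, Any]]],
) -> Optional[Union[str, List[str]]]:
    """Illustrative helper: prefer the OpenAI-native stop value, else flatten DashScope stop_words."""
    if stop:
        return stop
    if stop_words:
        stop_list = [
            sw["stop_str"]
            for sw in stop_words
            # only exclude-mode entries (the default) carry usable stop strings
            if sw.get("mode", "exclude") == "exclude" and "stop_str" in sw
        ]
        return stop_list or None
    return None

# A DashScope-style payload is flattened ...
print(convert_stop_words(None, [{"stop_str": "Observation:", "mode": "exclude"}]))
# -> ['Observation:']
# ... while a native stop value passes through untouched.
print(convert_stop_words(["\n\n"], None))
# -> ['\n\n']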
@@ -204,7 +287,7 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
                 return self._format_unary_response(raw_resp.parse(), trace_id)
 
         except APIError as e:
-            logger.error(f"Upstream API Error: {str(e)}")
+            logger.error(f"[request id: {initial_request_id}] Upstream API Error: {str(e)}")
             error_code = "InternalError"
             if isinstance(e, RateLimitError): error_code = "Throttling.RateQuota"
             elif isinstance(e, AuthenticationError): error_code = "InvalidApiKey"
@@ -359,6 +442,23 @@ def create_app() -> FastAPI:
         CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
     )
 
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(request, exc):
+        # Take the key details of the first error and assemble a DashScope-style error message
+        error_msg = exc.errors()[0].get("msg", "Invalid parameter")
+        loc = exc.errors()[0].get("loc", [])
+        param_name = loc[-1] if loc else "unknown"
+
+        logger.error(f"Validation Error: {exc.errors()}")
+
+        return JSONResponse(
+            status_code=400,
+            content={
+                "code": "InvalidParameter",
+                "message": f"<400> InternalError.Algo.InvalidParameter: Parameter {param_name} check failed: {error_msg}"
+            }
+        )
+
     @app.middleware("http")
     async def request_tracker(request: Request, call_next):
         SERVER_STATE.increment_request()
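
The RequestValidationError handler above replaces FastAPI's default 422 validation response with a DashScope-style 400 body. A self-contained sketch of what a client then sees, using a throwaway app and route rather than the real mock server (the exact msg text depends on the installed Pydantic version):

# Sketch only: a minimal FastAPI app wired with the same exception handler.
from typing import Optional

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from pydantic import BaseModel

app = FastAPI()

class Params(BaseModel):  # illustrative request model
    top_k: Optional[int] = None

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    error_msg = exc.errors()[0].get("msg", "Invalid parameter")
    loc = exc.errors()[0].get("loc", [])
    param_name = loc[-1] if loc else "unknown"
    return JSONResponse(
        status_code=400,
        content={
            "code": "InvalidParameter",
            "message": f"<400> InternalError.Algo.InvalidParameter: Parameter {param_name} check failed: {error_msg}",
        },
    )

@app.post("/generate")  # throwaway route for the sketch
async def generate(params: Params):
    return {"ok": True, "top_k": params.top_k}

client = TestClient(app)
resp = client.post("/generate", json={"top_k": "not-an-int"})
print(resp.status_code)     # 400 instead of FastAPI's default 422
print(resp.json()["code"])  # InvalidParameter
# The message names the offending parameter, e.g.
# "<400> InternalError.Algo.InvalidParameter: Parameter top_k check failed: Input should be a valid integer, ..."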
