Skip to content

Commit 475e86b

Browse files
committed
Update mock_server.py
Update mock_server.py
1 parent 4b16fa4 commit 475e86b

File tree

1 file changed

+23
-26
lines changed

1 file changed

+23
-26
lines changed

tests/mock_server.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
}
4444
DUMMY_KEY = "dummy-key"
4545

46+
MAX_NUM_MSG_CURL_DUMP = 5 # Max number of request messages kept in the debug curl dump; any beyond this are replaced by one "omitted" placeholder entry
47+
4648
# --- [Shared State] ---
4749

4850

@@ -201,24 +203,19 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
201203

202204
# === [thinking_budget 校验] ===
203205
if params.thinking_budget is not None:
204-
# 必须大于 0 (具体限制取决于上游,但负数肯定是无效的)
205206
if params.thinking_budget <= 0:
206207
return JSONResponse(
207208
status_code=400,
208209
content={
209210
"code": "InvalidParameter",
210-
# 保持与代码库中其他错误一致的格式
211211
"message": "<400> InternalError.Algo.InvalidParameter: thinking_budget should be greater than 0",
212212
},
213213
)
214214
# ===================================
215215

216-
# === [新增 response_format 校验逻辑] ===
216+
# === [response_format 校验] ===
217217
if params.response_format:
218218
rf_type = params.response_format.get("type")
219-
220-
# 校验 type 值是否合法
221-
# 允许的值通常为 json_object 或 text (根据测试用例报错信息)
222219
if rf_type and rf_type not in ["json_object", "text"]:
223220
return JSONResponse(
224221
status_code=400,
@@ -228,7 +225,6 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
228225
},
229226
)
230227

231-
# 校验 json_object 下不能包含 json_schema
232228
if rf_type == "json_object" and "json_schema" in params.response_format:
233229
return JSONResponse(
234230
status_code=400,
@@ -239,17 +235,13 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
239235
)
240236
# -----------------------------------------------
241237

242-
# 0. 提前拦截无效的 Tool Call 链 (Strict Validation)
243-
# 必须在 _convert_input_to_messages 之前执行,且依赖 req_data.input.messages 的原始结构
238+
# 0. 提前拦截无效的 Tool Call 链
244239
if req_data.input.messages:
245240
msgs = req_data.input.messages
246241
for idx, msg in enumerate(msgs):
247-
# 检查是否是发起调用的 assistant 消息
248-
# 注意:这里依赖 Pydantic 模型中已定义 tool_calls 或者是 extra="allow"
249242
has_tool_calls = getattr(msg, "tool_calls", None)
250243

251244
if msg.role == "assistant" and has_tool_calls:
252-
# 检查下一条消息
253245
next_idx = idx + 1
254246
if next_idx < len(msgs):
255247
next_msg = msgs[next_idx]
@@ -262,12 +254,10 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
262254
status_code=400,
263255
content={
264256
"code": "InvalidParameter",
265-
# 精确匹配错误格式
266257
"message": f'<400> InternalError.Algo.InvalidParameter: An assistant message with "tool_calls" must be followed by tool messages responding to each "tool_call_id". The following tool_call_ids did not have response messages: message[{next_idx}].role',
267258
},
268259
)
269260

270-
# --- Check if Tool message is preceded by Assistant (NEW LOGIC) ---
271261
if msg.role == "tool":
272262
is_orphan = False
273263
if idx == 0:
@@ -288,7 +278,6 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
288278
"message": '<400> InternalError.Algo.InvalidParameter: messages with role "tool" must be a response to a preceeding message with "tool_calls".',
289279
},
290280
)
291-
# -----------------------------------------------------
292281

293282
# Validation: Tools require message format
294283
if params.tools and params.result_format != "message":
@@ -322,10 +311,8 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
322311
"stream": params.incremental_output or params.enable_thinking,
323312
}
324313

325-
# === [将合法的 response_format 加入请求参数] ===
326314
if params.response_format:
327315
openai_params["response_format"] = params.response_format
328-
# ----------------------------------------------------
329316

330317
if params.frequency_penalty is not None:
331318
openai_params["frequency_penalty"] = params.frequency_penalty
@@ -368,9 +355,9 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
368355
# 如果用户传 0,通常对应 disable (即 -1) 或由模型决定,这里映射为 -1 比较稳妥
369356
# 如果用户传 > 100 (如 1025),必须截断为 100,否则上游报错
370357
if params.top_k == 0:
371-
extra_body["top_k"] = -1 # Disable
358+
extra_body["top_k"] = -1
372359
elif params.top_k > 100:
373-
extra_body["top_k"] = 100 # Clamp to max supported by upstream
360+
extra_body["top_k"] = 100
374361
else:
375362
extra_body["top_k"] = params.top_k
376363

@@ -379,7 +366,6 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
379366
# 确保只有在开启思考时才透传 budget,且前面已经校验过 >0
380367
# 如果上游明确支持 thinking_budget 字段:
381368
extra_body["thinking_budget"] = params.thinking_budget
382-
# ===============================================
383369

384370
if extra_body:
385371
openai_params["extra_body"] = extra_body
@@ -427,26 +413,37 @@ async def generate(self, req_data: GenerationRequest, initial_request_id: str):
427413
if stop_list:
428414
openai_params["stop"] = stop_list
429415

430-
# --- 生成 Curl 命令 (过滤掉非法 Header) ---
416+
# --- 生成 Curl 命令 (过滤掉非法 Header & Truncate Log) ---
431417
# 1. 基础 Header
432418
curl_headers = [
433419
'-H "Authorization: Bearer ${SILICONFLOW_API_KEY}"',
434420
"-H 'Content-Type: application/json'",
435421
]
436422

437-
# 2. 补充透传的 Header (过滤掉 Omit 对象和不需要的字段)
438-
# default_headers 包含了初始化时传入的 extra_headers
423+
# 2. 补充透传的 Header
439424
skip_keys = {"authorization", "content-type", "content-length", "host"}
440425
for k, v in self.client.default_headers.items():
441-
# 过滤掉 OpenAI 内部的 Omit 对象 和 系统自动生成的头
442426
if k.lower() not in skip_keys and not str(v).startswith("<openai."):
443427
curl_headers.append(f"-H '{k}: {v}'")
444428

445-
# 3. 组装命令
429+
# 3. Build a logging-only copy of the payload (truncated when it holds more than MAX_NUM_MSG_CURL_DUMP messages)
430+
log_payload = openai_params.copy()
431+
msg_list = log_payload.get("messages", [])
432+
if len(msg_list) > MAX_NUM_MSG_CURL_DUMP:
433+
# 截断展示: 只保留前 N 条,并添加说明
434+
truncated_msgs = msg_list[:MAX_NUM_MSG_CURL_DUMP]
435+
truncated_msgs.append(
436+
{
437+
"_log_truncation": f"... {len(msg_list) - MAX_NUM_MSG_CURL_DUMP} messages omitted for brevity ..."
438+
}
439+
)
440+
log_payload["messages"] = truncated_msgs
441+
442+
# 4. 组装命令
446443
curl_cmd = (
447444
f"curl -X POST {SILICON_FLOW_BASE_URL}/chat/completions \\\n "
448445
+ " \\\n ".join(curl_headers)
449-
+ f" \\\n -d '{json.dumps(openai_params, ensure_ascii=False)}'"
446+
+ f" \\\n -d '{json.dumps(log_payload, ensure_ascii=False)}'"
450447
)
451448

452449
logger.debug(f"[Curl Command]\n{curl_cmd}")

0 commit comments

Comments
 (0)