Skip to content

Commit 16996eb

Browse files
committed
Update mock_server.py
1 parent a96790e commit 16996eb

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

tests/mock_server.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,12 @@ def snapshot(self):
8787

8888
class Message(BaseModel):
8989
role: str
90-
content: str
90+
content: Optional[str] = "" # tool_calls 时 content 可能为空
91+
tool_calls: Optional[List[Dict[str, Any]]] = None
92+
tool_call_id: Optional[str] = None
93+
name: Optional[str] = None
94+
95+
model_config = ConfigDict(extra='allow')
9196

9297
class InputData(BaseModel):
9398
messages: Optional[List[Message]] = None
@@ -164,6 +169,32 @@ def _convert_input_to_messages(self, input_data: InputData) -> List[Dict[str, st
164169
return messages
165170

166171
async def generate(self, req_data: GenerationRequest, initial_request_id: str):
172+
# 0. 提前拦截无效的 Tool Call 链 (Strict Validation)
173+
# 必须在 _convert_input_to_messages 之前执行,且依赖 req_data.input.messages 的原始结构
174+
if req_data.input.messages:
175+
msgs = req_data.input.messages
176+
for idx, msg in enumerate(msgs):
177+
# 检查是否是发起调用的 assistant 消息
178+
# 注意:这里依赖 Pydantic 模型中已定义 tool_calls 或者是 extra="allow"
179+
has_tool_calls = getattr(msg, "tool_calls", None)
180+
181+
if msg.role == "assistant" and has_tool_calls:
182+
# 检查下一条消息
183+
next_idx = idx + 1
184+
if next_idx < len(msgs):
185+
next_msg = msgs[next_idx]
186+
# 规则:Assistant call 之后必须紧接 tool 消息
187+
if next_msg.role != "tool":
188+
logger.warning(f"Interceptor caught invalid tool chain at index {next_idx}")
189+
return JSONResponse(
190+
status_code=400,
191+
content={
192+
"code": "InvalidParameter",
193+
# 精确匹配错误格式
194+
"message": f"<400> InternalError.Algo.InvalidParameter: An assistant message with \"tool_calls\" must be followed by tool messages responding to each \"tool_call_id\". The following tool_call_ids did not have response messages: message[{next_idx}].role"
195+
}
196+
)
197+
167198
params = req_data.parameters
168199

169200
# Validation: Tools require message format

0 commit comments

Comments
 (0)