Skip to content

Commit 6f55997

Browse files
committed
Update mock_server.py
1 parent ef53f68 commit 6f55997

File tree

1 file changed

+38
-6
lines changed

1 file changed

+38
-6
lines changed

tests/mock_server.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,21 @@
3434
# --- Model Abstraction & Resolution ---
3535
from dataclasses import dataclass
3636

37+
3738
@dataclass
3839
class ModelSpec:
39-
real_model_name: str # 发送给上游 SiliconFlow 的真实模型名
40-
is_reasoning: bool = False # 是否为推理模型 (R1系列)
41-
supports_thinking: bool = False # 是否支持显式开启 thinking 参数 (V3通常不需要,R1自动开启)
40+
real_model_name: str # 发送给上游 SiliconFlow 的真实模型名
41+
is_reasoning: bool = False # 是否为推理模型 (R1系列)
42+
supports_thinking: bool = (
43+
False # 是否支持显式开启 thinking 参数 (V3通常不需要,R1自动开启)
44+
)
45+
4246

4347
class ModelResolver:
4448
"""
4549
模型解析器:负责把乱七八糟的模型路径解析成标准的能力配置
4650
"""
51+
4752
def __init__(self):
4853
# 基础映射表 (保留旧逻辑兼容)
4954
self._exact_map = {
@@ -76,9 +81,10 @@ def resolve(self, input_model: str) -> ModelSpec:
7681
is_reasoning=is_r1,
7782
# R1 自身就是推理模型,通常 API 不接受 enable_thinking 参数或者行为不同
7883
# V3 如果未来支持 thinking,可以在这里扩展逻辑
79-
supports_thinking=not is_r1
84+
supports_thinking=not is_r1,
8085
)
8186

87+
8288
# 全局单例
8389
model_resolver = ModelResolver()
8490

@@ -256,7 +262,11 @@ async def generate(
256262
model_spec = model_resolver.resolve(req_data.model)
257263

258264
# Model existence check (保留基础校验)
259-
if not skip_model_exist_check and req_data.model not in MODEL_MAPPING and not model_spec.real_model_name:
265+
if (
266+
not skip_model_exist_check
267+
and req_data.model not in MODEL_MAPPING
268+
and not model_spec.real_model_name
269+
):
260270
return JSONResponse(
261271
status_code=400,
262272
content={
@@ -655,8 +665,30 @@ async def generate(
655665
logger.error(
656666
f"[request id: {initial_request_id}] Upstream API Error: {str(e)}"
657667
)
668+
669+
# Default error code
658670
error_code = "InternalError"
659-
if isinstance(e, RateLimitError):
671+
672+
# Handle HTTP 400 errors (typically parameter errors or model not found)
673+
if e.status_code == 400:
674+
error_code = "InvalidParameter"
675+
# Check if the error message contains model not found indication
676+
error_msg = str(e)
677+
if (
678+
"Model does not exist" in error_msg
679+
or "model not found" in error_msg.lower()
680+
):
681+
# Return the exact message expected by tests
682+
return JSONResponse(
683+
status_code=400,
684+
content={
685+
"code": "InvalidParameter",
686+
"message": "Model not exist.",
687+
"request_id": initial_request_id,
688+
},
689+
)
690+
691+
elif isinstance(e, RateLimitError):
660692
error_code = "Throttling.RateQuota"
661693
elif isinstance(e, AuthenticationError):
662694
error_code = "InvalidApiKey"

0 commit comments

Comments
 (0)