Skip to content

Commit ef53f68

Browse files
committed
add
1 parent 68f13a8 commit ef53f68

File tree

1 file changed

+66
-20
lines changed

1 file changed

+66
-20
lines changed

tests/mock_server.py

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,58 @@
3131
)
3232
_MOCK_ENV_API_KEY = os.getenv("SILICON_FLOW_API_KEY")
3333

34-
MODEL_MAPPING = {
35-
"deepseek-v3": "deepseek-ai/DeepSeek-V3",
36-
"deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
37-
"deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
38-
"deepseek-r1": "deepseek-ai/DeepSeek-R1",
39-
"default": "deepseek-ai/DeepSeek-V3",
40-
"pre-siliconflow/deepseek-v3": "deepseek-ai/DeepSeek-V3",
41-
"pre-siliconflow/deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
42-
"pre-siliconflow/deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
43-
"pre-siliconflow/deepseek-r1": "deepseek-ai/DeepSeek-R1",
44-
}
34+
# --- Model Abstraction & Resolution ---
35+
from dataclasses import dataclass
36+
37+
@dataclass
class ModelSpec:
    """Capability profile produced by ``ModelResolver.resolve``."""

    # Model name that is sent upstream to SiliconFlow.
    real_model_name: str
    # True for reasoning models (the R1 family).
    is_reasoning: bool = False
    # True when the explicit "thinking" parameter may be enabled for the
    # model (per the original author: V3 normally does not need it and R1
    # turns it on by itself).
    supports_thinking: bool = False


class ModelResolver:
    """Turn a client-supplied model name into a ``ModelSpec``.

    Known shorthand aliases are translated through an exact-match table;
    every other non-empty name is forwarded upstream unchanged.
    """

    def __init__(self):
        # Alias table carried over from the previous MODEL_MAPPING constant
        # so that existing shorthand names keep working.
        self._exact_map = {
            "deepseek-v3": "deepseek-ai/DeepSeek-V3",
            "deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
            "deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
            "deepseek-r1": "deepseek-ai/DeepSeek-R1",
            "default": "deepseek-ai/DeepSeek-V3",
            "pre-siliconflow/deepseek-v3": "deepseek-ai/DeepSeek-V3",
            "pre-siliconflow/deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
            "pre-siliconflow/deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
            "pre-siliconflow/deepseek-r1": "deepseek-ai/DeepSeek-R1",
        }

    def resolve(self, input_model: Optional[str]) -> ModelSpec:
        """Resolve *input_model* to its upstream name and capability flags.

        Args:
            input_model: Model name as received from the client. ``None``
                is treated like an empty string so that the caller's own
                validation can reject it instead of this method raising
                ``AttributeError``.

        Returns:
            A ``ModelSpec`` whose ``real_model_name`` is the alias-table
            entry for the lower-cased name, or the whitespace-stripped
            input when no alias matches (empty for blank/``None`` input).
        """
        # 1. Normalise: only surrounding whitespace is removed; prefixes and
        #    paths are kept because the alias table keys include them.
        clean_name = (input_model or "").strip()

        # 2. Feature detection uses a case-insensitive substring match, which
        #    also covers full upstream paths such as "deepseek-ai/DeepSeek-R1".
        lower_name = clean_name.lower()
        is_r1 = "deepseek-r1" in lower_name

        # Shorthand names go through the table; anything else is assumed to
        # already be a full upstream path and is passed through as-is.
        upstream_name = self._exact_map.get(lower_name, clean_name)

        return ModelSpec(
            real_model_name=upstream_name,
            is_reasoning=is_r1,
            # R1 is a reasoning model in its own right; per the original
            # author the upstream API usually rejects (or treats differently)
            # an explicit enable_thinking flag for it. Extend this rule here
            # if thinking support ever becomes model-specific for V3.
            supports_thinking=not is_r1,
        )


# Module-wide singleton.
model_resolver = ModelResolver()

# Backward-compatible alias: this is the very same dict object as the
# resolver's alias table, so legacy lookups and the resolver stay in sync.
MODEL_MAPPING = model_resolver._exact_map
4586

4687
DUMMY_KEY = "dummy-key"
4788
MAX_NUM_MSG_CURL_DUMP = 5
@@ -166,6 +207,7 @@ def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None)
166207
**kv,
167208
)
168209

210+
# [Deprecated] Prefer model_resolver.resolve(); retained for legacy callers.
def _get_mapped_model(self, request_model: str) -> str:
    """Return the upstream model name registered for *request_model*.

    Names that are not keys of MODEL_MAPPING resolve to its "default" entry.
    """
    fallback = MODEL_MAPPING["default"]
    return MODEL_MAPPING.get(request_model, fallback)
171213

@@ -210,8 +252,11 @@ async def generate(
210252
force_stream: bool = False,
211253
skip_model_exist_check: bool = False,
212254
):
213-
# Model existence check
214-
if not skip_model_exist_check and req_data.model not in MODEL_MAPPING:
255+
# --- 1. 获取模型能力配置 ---
256+
model_spec = model_resolver.resolve(req_data.model)
257+
258+
# Model existence check (保留基础校验)
259+
if not skip_model_exist_check and req_data.model not in MODEL_MAPPING and not model_spec.real_model_name:
215260
return JSONResponse(
216261
status_code=400,
217262
content={
@@ -403,8 +448,10 @@ async def generate(
403448
},
404449
)
405450

406-
is_r1 = "deepseek-r1" in req_data.model or params.enable_thinking
407-
if is_r1 and params.tool_choice and isinstance(params.tool_choice, dict):
451+
# --- 2. 校验逻辑 (利用 flags 而不是字符串匹配) ---
452+
# 例如:R1 模型特判逻辑
453+
is_r1_logic = model_spec.is_reasoning or params.enable_thinking
454+
if is_r1_logic and params.tool_choice and isinstance(params.tool_choice, dict):
408455
return JSONResponse(
409456
status_code=400,
410457
content={
@@ -413,7 +460,7 @@ async def generate(
413460
},
414461
)
415462

416-
if "deepseek-r1" in req_data.model and params.enable_thinking:
463+
if model_spec.is_reasoning and params.enable_thinking:
417464
return JSONResponse(
418465
status_code=400,
419466
content={
@@ -441,7 +488,7 @@ async def generate(
441488
proxy_stop_list = list(set(proxy_stop_list))
442489

443490
# --- Request Parameters Assembly ---
444-
target_model = self._get_mapped_model(req_data.model)
491+
target_model = model_spec.real_model_name
445492
messages = self._convert_input_to_messages(req_data.input)
446493

447494
if params.enable_thinking:
@@ -571,7 +618,6 @@ async def generate(
571618

572619
# --- Execution ---
573620
try:
574-
is_r1_model = "deepseek-r1" in req_data.model
575621
if openai_params["stream"]:
576622
raw_resp = await self.client.chat.completions.with_raw_response.create(
577623
**openai_params
@@ -586,7 +632,7 @@ async def generate(
586632
trace_id,
587633
is_incremental=params.incremental_output,
588634
stop_sequences=proxy_stop_list,
589-
is_r1_model=is_r1_model,
635+
is_r1_model=model_spec.is_reasoning,
590636
),
591637
media_type="text/event-stream",
592638
headers={"X-SiliconCloud-Trace-Id": trace_id},
@@ -602,7 +648,7 @@ async def generate(
602648
raw_resp.parse(),
603649
trace_id,
604650
stop_sequences=proxy_stop_list,
605-
is_r1_model=is_r1_model,
651+
is_r1_model=model_spec.is_reasoning,
606652
)
607653

608654
except APIError as e:

0 commit comments

Comments
 (0)