 )
 _MOCK_ENV_API_KEY = os.getenv("SILICON_FLOW_API_KEY")
 
-MODEL_MAPPING = {
-    "deepseek-v3": "deepseek-ai/DeepSeek-V3",
-    "deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
-    "deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
-    "deepseek-r1": "deepseek-ai/DeepSeek-R1",
-    "default": "deepseek-ai/DeepSeek-V3",
-    "pre-siliconflow/deepseek-v3": "deepseek-ai/DeepSeek-V3",
-    "pre-siliconflow/deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
-    "pre-siliconflow/deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
-    "pre-siliconflow/deepseek-r1": "deepseek-ai/DeepSeek-R1",
-}
+# --- Model Abstraction & Resolution ---
+from dataclasses import dataclass
+
+@dataclass
+class ModelSpec:
+    real_model_name: str             # real model name sent upstream to SiliconFlow
+    is_reasoning: bool = False       # whether this is a reasoning model (R1 series)
+    supports_thinking: bool = False  # whether the explicit thinking parameter is supported (V3 usually does not need it; R1 enables it automatically)
+
+class ModelResolver:
+    """
+    Model resolver: normalizes arbitrary model paths into a standard capability configuration.
+    """
+    def __init__(self):
+        # Base mapping table (kept for backward compatibility with the old logic)
+        self._exact_map = {
+            "deepseek-v3": "deepseek-ai/DeepSeek-V3",
+            "deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
+            "deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
+            "deepseek-r1": "deepseek-ai/DeepSeek-R1",
+            "default": "deepseek-ai/DeepSeek-V3",
+            "pre-siliconflow/deepseek-v3": "deepseek-ai/DeepSeek-V3",
+            "pre-siliconflow/deepseek-v3.1": "deepseek-ai/DeepSeek-V3.1",
+            "pre-siliconflow/deepseek-v3.2": "deepseek-ai/DeepSeek-V3.2",
+            "pre-siliconflow/deepseek-r1": "deepseek-ai/DeepSeek-R1",
+        }
+
+    def resolve(self, input_model: str) -> ModelSpec:
+        # 1. Preprocess: normalize the incoming model name
+        clean_name = input_model.strip()
+
+        # 2. Detect core traits (substring matching is more flexible than exact matching)
+        lower_name = clean_name.lower()
+
+        is_r1 = "deepseek-r1" in lower_name
+
+        # The real SiliconFlow model name is usually the incoming path itself, or needs mapping:
+        # shorthand names are looked up in the table, full paths are used as-is.
+        upstream_name = self._exact_map.get(lower_name, clean_name)
+
+        return ModelSpec(
+            real_model_name=upstream_name,
+            is_reasoning=is_r1,
+            # R1 is itself a reasoning model; the API usually does not accept enable_thinking
+            # for it, or behaves differently. If V3 gains thinking support later, extend here.
+            supports_thinking=not is_r1,
+        )
+
+# Global singleton
+model_resolver = ModelResolver()
+
+MODEL_MAPPING = model_resolver._exact_map
 
 DUMMY_KEY = "dummy-key"
 MAX_NUM_MSG_CURL_DUMP = 5
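
For reference, a quick usage sketch of the resolver introduced above (illustrative only, not part of the diff; the expected values follow from the mapping table and the substring check shown here):

# Illustrative only: how ModelResolver / ModelSpec behave with the table above.
spec = model_resolver.resolve("pre-siliconflow/deepseek-r1")
# -> ModelSpec(real_model_name="deepseek-ai/DeepSeek-R1", is_reasoning=True, supports_thinking=False)

spec = model_resolver.resolve("deepseek-ai/DeepSeek-V3")  # full path not in the table: passed through as-is
# -> ModelSpec(real_model_name="deepseek-ai/DeepSeek-V3", is_reasoning=False, supports_thinking=True)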
@@ -166,6 +207,7 @@ def __init__(self, api_key: str, extra_headers: Optional[Dict[str, str]] = None)
             **kv,
         )
 
+    # [Deprecated] Use model_resolver.resolve() instead
     def _get_mapped_model(self, request_model: str) -> str:
        return MODEL_MAPPING.get(request_model, MODEL_MAPPING["default"])
 
@@ -210,8 +252,11 @@ async def generate(
         force_stream: bool = False,
         skip_model_exist_check: bool = False,
     ):
-        # Model existence check
-        if not skip_model_exist_check and req_data.model not in MODEL_MAPPING:
+        # --- 1. Resolve the model capability configuration ---
+        model_spec = model_resolver.resolve(req_data.model)
+
+        # Model existence check (basic validation kept)
+        if not skip_model_exist_check and req_data.model not in MODEL_MAPPING and not model_spec.real_model_name:
             return JSONResponse(
                 status_code=400,
                 content={
@@ -403,8 +448,10 @@ async def generate(
                 },
             )
 
-        is_r1 = "deepseek-r1" in req_data.model or params.enable_thinking
-        if is_r1 and params.tool_choice and isinstance(params.tool_choice, dict):
+        # --- 2. Validation logic (driven by flags instead of string matching) ---
+        # e.g. the special-case handling for R1 models
+        is_r1_logic = model_spec.is_reasoning or params.enable_thinking
+        if is_r1_logic and params.tool_choice and isinstance(params.tool_choice, dict):
             return JSONResponse(
                 status_code=400,
                 content={
@@ -413,7 +460,7 @@ async def generate(
                 },
             )
 
-        if "deepseek-r1" in req_data.model and params.enable_thinking:
+        if model_spec.is_reasoning and params.enable_thinking:
             return JSONResponse(
                 status_code=400,
                 content={
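
Taken together, these two rejections amount to a small decision rule over the resolved spec. A minimal standalone sketch, assuming a params object with enable_thinking and tool_choice as in the handler above (the helper name and messages are hypothetical, not part of this change):

from typing import Any, Optional

def reasoning_request_error(spec: ModelSpec, params: Any) -> Optional[str]:
    """Hypothetical helper mirroring the two checks above; returns an error message or None."""
    reasoning_mode = spec.is_reasoning or params.enable_thinking
    # Reasoning mode cannot be combined with a forced (dict-valued) tool_choice.
    if reasoning_mode and params.tool_choice and isinstance(params.tool_choice, dict):
        return "tool_choice objects are not supported when reasoning is active"
    # enable_thinking is rejected for models that already reason by default (R1).
    if spec.is_reasoning and params.enable_thinking:
        return "enable_thinking is not supported for R1 (reasoning) models"
    return None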
@@ -441,7 +488,7 @@ async def generate(
         proxy_stop_list = list(set(proxy_stop_list))
 
         # --- Request Parameters Assembly ---
-        target_model = self._get_mapped_model(req_data.model)
+        target_model = model_spec.real_model_name
         messages = self._convert_input_to_messages(req_data.input)
 
         if params.enable_thinking:
@@ -571,7 +618,6 @@ async def generate(
 
         # --- Execution ---
         try:
-            is_r1_model = "deepseek-r1" in req_data.model
             if openai_params["stream"]:
                 raw_resp = await self.client.chat.completions.with_raw_response.create(
                     **openai_params
@@ -586,7 +632,7 @@ async def generate(
                         trace_id,
                         is_incremental=params.incremental_output,
                         stop_sequences=proxy_stop_list,
-                        is_r1_model=is_r1_model,
+                        is_r1_model=model_spec.is_reasoning,
                     ),
                     media_type="text/event-stream",
                     headers={"X-SiliconCloud-Trace-Id": trace_id},
@@ -602,7 +648,7 @@ async def generate(
                     raw_resp.parse(),
                     trace_id,
                     stop_sequences=proxy_stop_list,
-                    is_r1_model=is_r1_model,
+                    is_r1_model=model_spec.is_reasoning,
                 )
 
         except APIError as e:
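
Since the streaming and non-streaming branches now both rely on model_spec.is_reasoning, a few pytest-style checks on the resolver would pin down the behavior this refactor depends on (a sketch; test names are mine, not part of the diff):

def test_shorthand_is_mapped():
    spec = model_resolver.resolve("deepseek-v3")
    assert spec.real_model_name == "deepseek-ai/DeepSeek-V3"
    assert not spec.is_reasoning

def test_r1_sets_reasoning_flag():
    spec = model_resolver.resolve("pre-siliconflow/deepseek-r1")
    assert spec.real_model_name == "deepseek-ai/DeepSeek-R1"
    assert spec.is_reasoning and not spec.supports_thinking

def test_unknown_full_path_passes_through():
    spec = model_resolver.resolve("  deepseek-ai/DeepSeek-V3.1 ")
    assert spec.real_model_name == "deepseek-ai/DeepSeek-V3.1"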