Skip to content

Commit 22e751d

Browse files
committed
refactor: streamline update models script
1 parent 9cd8ca6 commit 22e751d

File tree

1 file changed

+98
-114
lines changed

1 file changed

+98
-114
lines changed

scripts/update_models.py

Lines changed: 98 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
├─────────────────────────────────────────────────────────────────────┤
2323
│ id → model (模型 ID) │
2424
│ name → label.zh_Hans / label.en_US (双语名称) │
25-
│ features → features (特性列表,需转换) │
26-
│ ├─ "tools" → "tool-call" │
27-
│ ├─ "vision" → "vision" │
28-
│ └─ "tool-call" → 自动添加 "stream-tool-call" │
25+
│ input_modalities → features (特性列表) │
26+
│ ├─ 默认 → ["tool-call", "stream-tool-call"] │
27+
│ └─ 含 image → 额外添加 "vision" │
2928
│ model_constraints → model_properties │
3029
│ └─ context_length → context_size (默认 65536, 64k) │
3130
│ 默认参数 → parameter_rules │
@@ -42,8 +41,8 @@
4241
google/gemini-2.5-flash → google-gemini-2.5-flash.yaml
4342
4443
特性默认值:
45-
- 如果模型没有声明任何特性,默认添加: ["tool-call", "stream-tool-call"]
46-
- 如果有 tool-call 特性,自动添加 stream-tool-call
44+
- 始终输出 ["tool-call", "stream-tool-call"]
45+
- 如果输入模态包含 image,额外添加 "vision"
4746
4847
文件管理策略:
4948
- 新增:创建新的 YAML 文件
@@ -65,7 +64,7 @@
6564
import yaml
6665
import requests
6766
from pathlib import Path
68-
from typing import List, Dict, Any, Optional
67+
from typing import List, Dict, Any
6968

7069
# 七牛云市场 API 端点
7170
MARKET_API_URL = "https://openai.sufy.com/v1/market/models?overseas=true"
@@ -74,7 +73,27 @@
7473
SCRIPT_DIR = Path(__file__).parent
7574
PROJECT_ROOT = SCRIPT_DIR.parent
7675
MODELS_DIR = PROJECT_ROOT / "ai-models-provider" / "models" / "llm"
77-
POSITION_FILE = MODELS_DIR / "_position.yaml"
76+
POSITION_FILENAME = "_position.yaml"
77+
POSITION_FILE = MODELS_DIR / POSITION_FILENAME
78+
VISION_MODALITIES = {"image"}
79+
DEFAULT_CONTEXT_LENGTH = 65_536
80+
DEFAULT_FEATURES = ["tool-call", "stream-tool-call"]
81+
PARAMETER_RULE_TEMPLATES = [
82+
{
83+
"name": "temperature",
84+
"use_template": "temperature",
85+
},
86+
{
87+
"name": "top_p",
88+
"use_template": "top_p",
89+
},
90+
{
91+
"name": "max_tokens",
92+
"use_template": "max_tokens",
93+
},
94+
]
95+
96+
ModelInfo = Dict[str, Any]
7897

7998
# CI 环境检测
8099
IS_CI = os.getenv("CI", "").lower() in ("true", "1", "yes")
@@ -83,7 +102,12 @@
83102
models_with_missing_fields = []
84103

85104

86-
def is_llm_model(model_info: Dict[str, Any]) -> bool:
105+
def has_vision_capability(input_modalities: List[str]) -> bool:
    """Report whether any listed input modality counts as a vision modality.

    Args:
        input_modalities: Modality names declared for a model. A falsy value
            (``None`` or an empty list) is treated as "no modalities".

    Returns:
        True if at least one entry is a member of ``VISION_MODALITIES``,
        otherwise False.
    """
    # Tolerate a missing or empty modality list instead of failing on iteration.
    if not input_modalities:
        return False

    for modality in input_modalities:
        if modality in VISION_MODALITIES:
            return True
    return False
108+
109+
110+
def is_llm_model(model_info: ModelInfo) -> bool:
87111
"""
88112
判断是否为文本 LLM 模型
89113
@@ -101,7 +125,7 @@ def is_llm_model(model_info: Dict[str, Any]) -> bool:
101125
return "text" in input_modalities and "text" in output_modalities
102126

103127

104-
def get_model_features(model_info: Dict[str, Any]) -> List[str]:
128+
def get_model_features(model_info: ModelInfo) -> List[str]:
105129
"""
106130
根据模型信息获取支持的特性
107131
@@ -111,33 +135,18 @@ def get_model_features(model_info: Dict[str, Any]) -> List[str]:
111135
Returns:
112136
特性列表
113137
"""
114-
features = []
115-
116-
# 从 features 字段获取
117-
api_features = model_info.get("features", [])
118-
119-
# 映射 API 特性到 Dify 特性
120-
feature_mapping = {
121-
"tools": "tool-call",
122-
"vision": "vision",
123-
}
138+
# 默认认为支持工具调用能力
139+
features = list(DEFAULT_FEATURES)
124140

125-
for api_feature in api_features:
126-
if api_feature in feature_mapping:
127-
features.append(feature_mapping[api_feature])
128-
129-
# 如果支持工具调用,默认也支持流式工具调用
130-
if "tool-call" in features:
131-
features.append("stream-tool-call")
132-
133-
# 如果没有任何特性,默认添加工具调用(大多数 LLM 都支持)
134-
if not features:
135-
features = ["tool-call", "stream-tool-call"]
141+
architecture = model_info.get("architecture", {})
142+
input_modalities = architecture.get("input_modalities", []) or []
143+
if has_vision_capability(input_modalities) and "vision" not in features:
144+
features.append("vision")
136145

137146
return features
138147

139148

140-
def get_model_context_size(model_info: Dict[str, Any]) -> int:
149+
def get_model_context_size(model_info: ModelInfo) -> int:
141150
"""
142151
获取模型上下文大小
143152
@@ -163,10 +172,10 @@ def get_model_context_size(model_info: Dict[str, Any]) -> int:
163172
if IS_CI:
164173
models_with_missing_fields.append(error_msg)
165174

166-
return 65536
175+
return DEFAULT_CONTEXT_LENGTH
167176

168177

169-
def generate_model_yaml(model_info: Dict[str, Any]) -> Dict[str, Any]:
178+
def generate_model_yaml(model_info: ModelInfo) -> Dict[str, Any]:
170179
"""
171180
生成模型的 YAML 配置
172181
@@ -192,26 +201,13 @@ def generate_model_yaml(model_info: Dict[str, Any]) -> Dict[str, Any]:
192201
"mode": "chat",
193202
"context_size": get_model_context_size(model_info),
194203
},
195-
"parameter_rules": [
196-
{
197-
"name": "temperature",
198-
"use_template": "temperature",
199-
},
200-
{
201-
"name": "top_p",
202-
"use_template": "top_p",
203-
},
204-
{
205-
"name": "max_tokens",
206-
"use_template": "max_tokens",
207-
},
208-
],
204+
"parameter_rules": [template.copy() for template in PARAMETER_RULE_TEMPLATES],
209205
}
210206

211207
return config
212208

213209

214-
def fetch_models_from_api() -> List[Dict[str, Any]]:
210+
def fetch_models_from_api() -> List[ModelInfo]:
215211
"""
216212
从七牛云市场 API 获取模型列表
217213
@@ -279,16 +275,6 @@ def fetch_models_from_api() -> List[Dict[str, Any]]:
279275
return []
280276

281277

282-
def get_existing_models() -> List[str]:
283-
"""获取现有的模型文件名列表(不包含扩展名)"""
284-
models = []
285-
for file in MODELS_DIR.glob("*.yaml"):
286-
if file.name not in ["_position.yaml"]:
287-
model_filename = file.stem
288-
models.append(model_filename)
289-
return models
290-
291-
292278
def sanitize_filename(model_id: str) -> str:
293279
"""
294280
将模型 ID 转换为合法的文件名
@@ -313,16 +299,15 @@ def sanitize_filename(model_id: str) -> str:
313299
return filename
314300

315301

316-
def get_existing_models() -> List[str]:
317-
"""获取现有的模型列表"""
318-
models = []
319-
for file in MODELS_DIR.glob("*.yaml"):
320-
if file.name not in ["_position.yaml"]:
321-
model_id = file.stem
322-
models.append(model_id)
323-
return models
302+
def get_existing_model_files() -> Dict[str, Path]:
    """Collect the model YAML files currently present in ``MODELS_DIR``.

    Returns:
        A dict keyed by file stem (the file name without ``.yaml``) whose
        values are the corresponding paths. The position file named by
        ``POSITION_FILENAME`` is left out.
    """
    model_files: Dict[str, Path] = {}
    for yaml_path in MODELS_DIR.glob("*.yaml"):
        # The position file shares the directory but is not a model definition.
        if yaml_path.name == POSITION_FILENAME:
            continue
        model_files[yaml_path.stem] = yaml_path
    return model_files
324309

325-
def update_model_files(models: List[Dict[str, Any]]) -> tuple[List[str], List[str], List[str]]:
310+
def update_model_files(models: List[ModelInfo]) -> tuple[List[str], List[str], List[str]]:
326311
"""
327312
更新模型配置文件
328313
@@ -332,24 +317,22 @@ def update_model_files(models: List[Dict[str, Any]]) -> tuple[List[str], List[st
332317
Returns:
333318
(新增的模型 ID 列表, 更新的模型 ID 列表, 删除的模型文件名列表)
334319
"""
335-
existing_models = set(get_existing_models())
336-
new_model_filenames = set(sanitize_filename(m["id"]) for m in models)
337-
338-
added = []
339-
updated = []
340-
removed = []
320+
existing_models = get_existing_model_files()
321+
new_model_filenames = set()
322+
added: List[str] = []
323+
updated: List[str] = []
324+
removed: List[str] = []
341325

342326
# 新增或更新模型
343327
for model_info in models:
344328
model_id = model_info["id"]
345329
model_name = model_info.get("name", model_id)
346330
filename = sanitize_filename(model_id)
347331
file_path = MODELS_DIR / f"{filename}.yaml"
332+
new_model_filenames.add(filename)
348333

349-
# 生成配置
350334
config = generate_model_yaml(model_info)
351335

352-
# 判断是新增还是更新
353336
if filename not in existing_models:
354337
print(f" + 新增: {model_id}")
355338
if model_name != model_id:
@@ -359,22 +342,20 @@ def update_model_files(models: List[Dict[str, Any]]) -> tuple[List[str], List[st
359342
print(f" ↻ 更新: {model_id}")
360343
updated.append(model_id)
361344

362-
# 写入 YAML 文件
363345
with open(file_path, "w", encoding="utf-8") as f:
364346
yaml.dump(config, f, allow_unicode=True, sort_keys=False, default_flow_style=False)
365347

366348
# 删除不在新列表中的模型(全量更新)
367-
for filename in existing_models:
349+
for filename, file_path in existing_models.items():
368350
if filename not in new_model_filenames:
369-
file_path = MODELS_DIR / f"{filename}.yaml"
370351
print(f" - 删除: {filename}")
371352
file_path.unlink()
372353
removed.append(filename)
373354

374355
return added, updated, removed
375356

376357

377-
def update_position_file(models: List[Dict[str, Any]]):
358+
def update_position_file(models: List[ModelInfo]):
378359
"""
379360
更新 _position.yaml 文件
380361
@@ -411,6 +392,40 @@ def update_position_file(models: List[Dict[str, Any]]):
411392
print(f"✓ 已更新 _position.yaml,共 {len(ordered_models)} 个模型")
412393

413394

395+
def summarize_changes(added: List[str], updated: List[str], removed: List[str], total: int) -> None:
    """Print a summary of the update run to stdout.

    Args:
        added: Identifiers of models that were newly created.
        updated: Identifiers of models that were rewritten.
        removed: Identifiers of models whose files were deleted.
        total: Total number of models, printed as given.
    """
    separator = "=" * 70
    report_lines = [
        "",  # blank line before the banner
        separator,
        "更新完成!",
        f" 新增模型: {len(added)} 个",
        f" 更新模型: {len(updated)} 个",
        f" 删除模型: {len(removed)} 个",
        f" 总计模型: {total} 个",
        separator,
    ]
    for line in report_lines:
        print(line)
404+
405+
406+
def enforce_ci_requirements() -> None:
    """Abort the run in CI when any model was recorded with missing fields.

    Does nothing unless ``IS_CI`` is truthy and ``models_with_missing_fields``
    is non-empty. Otherwise it prints every recorded error and a count to
    stdout, then terminates the process with exit status 1.

    Raises:
        SystemExit: With code 1 when running in CI and errors were recorded.
    """
    # Guard clause: outside CI, or with nothing recorded, there is nothing to do.
    if not (IS_CI and models_with_missing_fields):
        return

    separator = "=" * 70
    print()
    print(separator)
    print("✗ 错误:在 CI 环境中检测到模型缺少必需字段")
    print(separator)
    for missing_field_error in models_with_missing_fields:
        print(f" ✗ {missing_field_error}")
    print()
    print(f"共 {len(models_with_missing_fields)} 个模型缺少必需字段")
    print("请联系 API 提供方修复数据完整性问题")
    print(separator)
    sys.exit(1)
419+
420+
421+
def exit_with_change_status(added: List[str], updated: List[str], removed: List[str]) -> None:
    """Terminate the process with a status that reflects whether anything changed.

    Exits with status 0 when at least one of the three lists is non-empty.
    Otherwise prints a "no changes detected" notice and exits with status 1.

    NOTE(review): status 1 here means "ran fine, nothing changed", not a
    failure — presumably so a CI job can decide whether a commit is needed;
    confirm against the calling workflow.

    Args:
        added: Identifiers of newly created models.
        updated: Identifiers of rewritten models.
        removed: Identifiers of deleted models.

    Raises:
        SystemExit: Always; code 0 with changes, code 1 without.
    """
    nothing_changed = not (added or updated or removed)
    if nothing_changed:
        print()
        print("提示:没有检测到模型变更")
        sys.exit(1)
    sys.exit(0)
427+
428+
414429
def main():
415430
"""主函数"""
416431
print("=" * 70)
@@ -440,40 +455,9 @@ def main():
440455

441456
# 更新 position 文件
442457
update_position_file(models)
443-
444-
print()
445-
print("=" * 70)
446-
print("更新完成!")
447-
print(f" 新增模型: {len(added)} 个")
448-
print(f" 更新模型: {len(updated)} 个")
449-
print(f" 删除模型: {len(removed)} 个")
450-
print(f" 总计模型: {len(models)} 个")
451-
print("=" * 70)
452-
453-
# 在 CI 环境中,如果有模型缺少必需字段,报错并退出
454-
if IS_CI and models_with_missing_fields:
455-
print()
456-
print("=" * 70)
457-
print("✗ 错误:在 CI 环境中检测到模型缺少必需字段")
458-
print("=" * 70)
459-
for error in models_with_missing_fields:
460-
print(f" ✗ {error}")
461-
print()
462-
print(f"共 {len(models_with_missing_fields)} 个模型缺少必需字段")
463-
print("请联系 API 提供方修复数据完整性问题")
464-
print("=" * 70)
465-
sys.exit(1)
466-
467-
# 返回退出码
468-
# 0: 成功且有变更
469-
# 1: 成功但无变更(可用于 CI/CD 判断是否需要提交)
470-
has_changes = bool(added or updated or removed)
471-
if has_changes:
472-
sys.exit(0)
473-
else:
474-
print()
475-
print("提示:没有检测到模型变更")
476-
sys.exit(1)
458+
summarize_changes(added, updated, removed, len(models))
459+
enforce_ci_requirements()
460+
exit_with_change_status(added, updated, removed)
477461

478462

479463
if __name__ == "__main__":

0 commit comments

Comments
 (0)