Skip to content

Commit afc0664

Browse files
committed
feat: 更新模型配置和选择逻辑,修改为通过 model_spec 统一指定模型格式
1 parent be40799 commit afc0664

File tree

9 files changed

+159
-46
lines changed

9 files changed

+159
-46
lines changed

docs/intro/model-config.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@
2121

2222
<<< @/../.env.template#model_provider{bash 2}
2323

24+
### 默认对话模型格式
25+
26+
系统的默认对话模型通过配置项 `default_model` 指定,格式统一为 `模型提供商/模型名称`,例如:
27+
28+
```yaml
29+
default_model: siliconflow/deepseek-ai/DeepSeek-V3.2-Exp
30+
```
31+
32+
在 Web 界面中选择模型时也会自动按照这一格式保存,无需手动拆分提供商和模型名称。
33+
2434
2535
::: tip 免费获取 API Key
2636
[硅基流动](https://cloud.siliconflow.cn/i/Eo5yTHGJ) 注册即送 14 元额度,支持多种开源模型。

server/routers/chat_router.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ async def set_default_agent(request_data: dict = Body(...), current_user=Depends
8585
async def call(query: str = Body(...), meta: dict = Body(None), current_user: User = Depends(get_required_user)):
8686
"""调用模型进行简单问答(需要登录)"""
8787
meta = meta or {}
88-
model = select_model(model_provider=meta.get("model_provider"), model_name=meta.get("model_name"))
88+
model = select_model(
89+
model_provider=meta.get("model_provider"),
90+
model_name=meta.get("model_name"),
91+
model_spec=meta.get("model_spec") or meta.get("model"),
92+
)
8993

9094
async def call_async(query):
9195
loop = asyncio.get_event_loop()

src/config/app.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,11 @@ def __init__(self):
5858
# 模型配置
5959
## 注意这里是模型名,而不是具体的模型路径,默认使用 HuggingFace 的路径
6060
## 如果需要自定义本地模型路径,则在 .env 中配置 MODEL_DIR
61-
self.add_item("model_provider", default="siliconflow", des="模型提供商", choices=list(self.model_names.keys()))
62-
self.add_item("model_name", default="zai-org/GLM-4.5", des="模型名称")
61+
self.add_item(
62+
"default_model",
63+
default=self._get_default_chat_model_spec(),
64+
des="默认对话模型",
65+
)
6366
self.add_item(
6467
"fast_model",
6568
default="siliconflow/THUDM/GLM-4-9B-0414",
@@ -81,6 +84,9 @@ def __init__(self):
8184
### <<< 默认配置结束
8285

8386
self.load()
87+
# 清理已废弃的配置项
88+
self.pop("model_provider", None)
89+
self.pop("model_name", None)
8490
self.handle_self()
8591

8692
def add_item(self, key, default, des=None, choices=None):
@@ -137,6 +143,22 @@ def _save_models_to_file(self):
137143
with open(self._models_config_path, "w", encoding="utf-8") as f:
138144
yaml.safe_dump(models_payload, f, indent=2, allow_unicode=True, sort_keys=False)
139145

146+
def _get_default_chat_model_spec(self):
147+
"""选择一个默认的聊天模型,优先使用 siliconflow 的默认模型"""
148+
preferred_provider = "siliconflow"
149+
provider_info = (self.model_names or {}).get(preferred_provider)
150+
if provider_info:
151+
default_model = provider_info.get("default")
152+
if default_model:
153+
return f"{preferred_provider}/{default_model}"
154+
155+
for provider, info in (self.model_names or {}).items():
156+
default_model = info.get("default")
157+
if default_model:
158+
return f"{provider}/{default_model}"
159+
160+
return ""
161+
140162
def handle_self(self):
141163
"""
142164
处理配置

src/knowledge/implementations/lightrag.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,16 +174,19 @@ def _get_llm_func(self, llm_info: dict):
174174
from src.models import select_model
175175

176176
# 如果用户选择了LLM,使用用户选择的;否则使用环境变量默认值
177-
if llm_info and llm_info.get("provider") and llm_info.get("model_name"):
178-
provider = llm_info["provider"]
179-
model_name = llm_info["model_name"]
177+
if llm_info and llm_info.get("model_spec"):
178+
model_spec = llm_info["model_spec"]
179+
logger.info(f"Using user-selected LLM spec: {model_spec}")
180+
elif llm_info and llm_info.get("provider") and llm_info.get("model_name"):
181+
model_spec = f"{llm_info['provider']}/{llm_info['model_name']}"
180182
logger.info(f"Using user-selected LLM: {model_spec}")
181183
else:
182184
provider = LIGHTRAG_LLM_PROVIDER
183185
model_name = LIGHTRAG_LLM_NAME
186+
model_spec = f"{provider}/{model_name}"
184187
logger.info(f"Using default LLM from environment: {provider}/{model_name}")
185188

186-
model = select_model(provider, model_name)
189+
model = select_model(model_spec=model_spec)
187190

188191
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
189192
return await openai_complete_if_cache(

src/models/chat.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@
88
from src.utils import logger
99

1010

11+
def split_model_spec(model_spec, sep="/"):
12+
"""
13+
将 provider/model 形式的字符串拆分为 (provider, model)
14+
"""
15+
if not model_spec or not isinstance(model_spec, str):
16+
return "", ""
17+
if not sep:
18+
return model_spec, ""
19+
try:
20+
provider, model_name = model_spec.split(sep, 1)
21+
return provider, model_name
22+
except ValueError:
23+
return model_spec, ""
24+
25+
1126
class OpenAIBase:
1227
def __init__(self, api_key, base_url, model_name, **kwargs):
1328
self.api_key = api_key
@@ -85,12 +100,26 @@ def __init__(self, content):
85100
self.is_full = False
86101

87102

88-
def select_model(model_provider, model_name=None):
103+
def select_model(model_provider=None, model_name=None, model_spec=None):
89104
"""根据模型提供者选择模型"""
90-
assert model_provider is not None, "Model provider not specified"
105+
if model_spec:
106+
spec_provider, spec_model_name = split_model_spec(model_spec)
107+
model_provider = model_provider or spec_provider
108+
model_name = model_name or spec_model_name
109+
110+
if not model_provider or not model_name:
111+
default_provider, default_model = split_model_spec(getattr(config, "default_model", ""))
112+
model_provider = model_provider or default_provider
113+
model_name = model_name or default_model
114+
115+
if not model_provider:
    raise ValueError("Model provider not specified")
116+
91117
model_info = config.model_names.get(model_provider, {})
92118
model_name = model_name or model_info.get("default", "")
93119

120+
if not model_name:
121+
raise ValueError(f"Model name not specified for provider {model_provider}")
122+
94123
logger.info(f"Selecting model from `{model_provider}` with `{model_name}`")
95124

96125
if model_provider == "openai":

test/bruteforce_simulation.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ def parse_args() -> argparse.Namespace:
2323
parser = argparse.ArgumentParser(description="Simulate brute-force login attempts.")
2424
parser.add_argument("--base-url", default=os.getenv("TEST_BASE_URL", "http://localhost:5050"), help="API base URL")
2525
parser.add_argument("--username", default=os.getenv("TEST_USERNAME", "admin"), help="Login identifier to attack")
26-
parser.add_argument(
27-
"--attempts", type=int, default=20, help="Total number of attempts to issue (default: 20)"
28-
)
26+
parser.add_argument("--attempts", type=int, default=20, help="Total number of attempts to issue (default: 20)")
2927
parser.add_argument(
3028
"--concurrency",
3129
type=int,
@@ -68,10 +66,13 @@ async def attempt_login(
6866
started = time.perf_counter()
6967
response = await client.post("/api/auth/token", data=payload)
7068
elapsed = time.perf_counter() - started
71-
detail = response.json().get("detail") if response.headers.get("content-type", "").startswith("application/json") else response.text
69+
detail = (
70+
response.json().get("detail")
71+
if response.headers.get("content-type", "").startswith("application/json")
72+
else response.text
73+
)
7274
print(
73-
f"[{attempt_no:02d}] {response.status_code} in {elapsed*1000:.1f} ms "
74-
f"(pwd={password!r}) detail={detail!r}"
75+
f"[{attempt_no:02d}] {response.status_code} in {elapsed * 1000:.1f} ms (pwd={password!r}) detail={detail!r}"
7576
)
7677
return response.status_code, elapsed
7778

@@ -86,9 +87,7 @@ async def run_simulation(args: argparse.Namespace) -> int:
8687
tasks = []
8788
for attempt_no in range(1, args.attempts + 1):
8889
tasks.append(
89-
asyncio.create_task(
90-
attempt_login(client, semaphore, attempt_no, args.username, args.password)
91-
)
90+
asyncio.create_task(attempt_login(client, semaphore, attempt_no, args.username, args.password))
9291
)
9392
if args.delay:
9493
await asyncio.sleep(args.delay)

web/src/components/AgentConfigSidebar.vue

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@
5656
<div v-if="value.template_metadata.kind === 'llm'" class="model-selector">
5757
<ModelSelectorComponent
5858
@select-model="handleModelChange"
59-
:model_name="agentConfig[key] ? agentConfig[key].split('/').slice(1).join('/') : ''"
60-
:model_provider="agentConfig[key] ? agentConfig[key].split('/')[0] : ''"
59+
:model_spec="agentConfig[key] || ''"
6160
/>
6261
</div>
6362

@@ -367,9 +366,10 @@ const getPlaceholder = (key, value) => {
367366
return `(默认: ${value.default})`;
368367
};
369368
370-
const handleModelChange = (data) => {
369+
const handleModelChange = (spec) => {
370+
if (typeof spec !== 'string' || !spec) return;
371371
agentStore.updateAgentConfig({
372-
model: `${data.provider}/${data.name}`
372+
model: spec
373373
});
374374
};
375375

web/src/components/ModelSelectorComponent.vue

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
<div class="model-select" @click.prevent>
44
<div class="model-select-content">
55
<div class="model-info">
6-
<a-tooltip :title="model_name" placement="right">
7-
<span class="model-text text"> {{ model_name }} </span>
6+
<a-tooltip :title="displayModelText" placement="right">
7+
<span class="model-text text"> {{ displayModelText }} </span>
88
</a-tooltip>
9-
<span class="model-provider">{{ model_provider }}</span>
9+
<span class="model-provider">{{ displayModelProvider }}</span>
1010
</div>
1111
<div class="model-status-controls">
1212
<span
@@ -48,13 +48,17 @@ import { useConfigStore } from '@/stores/config'
4848
import { chatModelApi } from '@/apis/system_api'
4949
5050
const props = defineProps({
51-
model_name: {
51+
model_spec: {
5252
type: String,
5353
default: ''
5454
},
55-
model_provider: {
55+
sep: {
5656
type: String,
57-
default: ''
57+
default: '/'
58+
},
59+
placeholder: {
60+
type: String,
61+
default: '请选择模型'
5862
}
5963
});
6064
@@ -76,25 +80,48 @@ const modelKeys = computed(() => {
7680
return Object.keys(modelStatus.value || {}).filter(key => modelStatus.value?.[key])
7781
})
7882
83+
const resolvedSep = computed(() => props.sep || '/')
84+
85+
const resolvedModel = computed(() => {
86+
const spec = props.model_spec || ''
87+
const sep = resolvedSep.value
88+
if (spec && sep) {
89+
const index = spec.indexOf(sep)
90+
if (index !== -1) {
91+
const provider = spec.slice(0, index)
92+
const name = spec.slice(index + sep.length)
93+
if (provider && name) {
94+
return { provider, name }
95+
}
96+
}
97+
}
98+
return { provider: '', name: '' }
99+
})
100+
101+
const displayModelProvider = computed(() => resolvedModel.value.provider || '')
102+
const displayModelName = computed(() => resolvedModel.value.name || '')
103+
const displayModelText = computed(() => displayModelName.value || props.placeholder)
104+
79105
// 当前模型状态
80106
const currentModelStatus = computed(() => {
81107
return state.currentModelStatus
82108
})
83109
84110
// 检查当前模型状态
85111
const checkCurrentModelStatus = async () => {
86-
if (!props.model_provider || !props.model_name) return
112+
const { provider, name } = resolvedModel.value
113+
if (!provider || !name) return
87114
88115
try {
89116
state.checkingStatus = true
90-
const response = await chatModelApi.getModelStatus(props.model_provider, props.model_name)
117+
const response = await chatModelApi.getModelStatus(provider, name)
91118
if (response.status) {
92119
state.currentModelStatus = response.status
93120
} else {
94121
state.currentModelStatus = null
95122
}
96123
} catch (error) {
97-
console.error(`检查当前模型 ${props.model_provider}/${props.model_name} 状态失败:`, error)
124+
console.error(`检查当前模型 ${provider}/${name} 状态失败:`, error)
98125
state.currentModelStatus = { status: 'error', message: error.message }
99126
} finally {
100127
state.checkingStatus = false
@@ -128,24 +155,27 @@ const getCurrentModelStatusTooltip = () => {
128155
129156
// 选择模型的方法
130157
const handleSelectModel = async (provider, name) => {
131-
emit('select-model', { provider, name })
158+
const separator = resolvedSep.value || '/'
159+
const spec = `${provider}${separator}${name}`
160+
emit('select-model', spec)
132162
}
133163
134164
</script>
135165
136166
<style lang="less" scoped>
137167
// 变量定义
138-
@status-success: #52c41a;
139-
@status-error: #ff4d4f;
140-
@status-warning: #faad14;
141-
@status-default: #999;
168+
@status-success: var(--color-success);
169+
@status-error: var(--color-error);
170+
@status-warning: var(--chart-warning);
171+
@status-default: var(--gray-500);
142172
@border-radius: 8px;
143173
@scrollbar-width: 6px;
144174
@status-indicator-padding: 2px 4px;
145175
@status-check-button-padding: 0 4px;
146176
@status-check-button-font-size: 12px;
147177
@status-indicator-font-size: 11px;
148-
@model-provider-color: #aaa;
178+
@model-provider-color: var(--gray-500);
149179
150180
// 主选择器样式
151181
.model-select {
@@ -156,11 +186,12 @@ const handleSelectModel = async (provider, name) => {
156186
cursor: pointer;
157187
border: 1px solid var(--gray-200);
158188
border-radius: @border-radius;
159-
background-color: white;
189+
background-color: var(--gray-0);
160190
min-width: 0;
161191
display: flex;
162192
align-items: center;
163193
gap: 0.5rem;
194+
font-size: 13px;
164195
165196
// 修饰符类
166197
&.borderless {
@@ -188,7 +219,7 @@ const handleSelectModel = async (provider, name) => {
188219
.model-text {
189220
overflow: hidden;
190221
text-overflow: ellipsis;
191-
color: #000;
222+
color: var(--gray-1000);
192223
white-space: nowrap;
193224
}
194225
@@ -270,4 +301,4 @@ const handleSelectModel = async (provider, name) => {
270301
overflow-y: auto;
271302
}
272303
}
273-
</style>
304+
</style>

0 commit comments

Comments
 (0)