forked from YILING0013/AI_NovelGenerator
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig_manager.py
More file actions
80 lines (71 loc) · 3 KB
/
config_manager.py
File metadata and controls
80 lines (71 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# config_manager.py
# -*- coding: utf-8 -*-
import json
import os
import threading
from llm_adapters import create_llm_adapter
from embedding_adapters import create_embedding_adapter
def load_config(config_file: str) -> dict:
    """Load configuration from *config_file*; return an empty dict on failure.

    An empty dict is returned when the file does not exist, cannot be read or
    decoded as UTF-8, contains invalid JSON, or holds a top-level JSON value
    that is not an object. Callers can therefore always treat the result as
    a dict.
    """
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing or unreadable path (FileNotFoundError,
        # IsADirectoryError, PermissionError are subclasses).
        # ValueError covers invalid JSON (json.JSONDecodeError) and invalid
        # UTF-8 (UnicodeDecodeError).
        # A missing or broken config is treated as "no config", as before.
        return {}
    # Valid JSON may still be a list/str/number at the top level; honour the
    # declared dict return type instead of handing callers something else.
    return data if isinstance(data, dict) else {}
def save_config(config_data: dict, config_file: str) -> bool:
    """Save *config_data* as pretty-printed UTF-8 JSON to *config_file*.

    The data is serialized in full before any file is touched. It is then
    written to a sibling temporary file and moved over *config_file* with
    ``os.replace``. A failure therefore leaves any existing config file intact
    instead of truncated.

    Returns:
        True if the file was written, False if serialization or writing failed.
    """
    try:
        # Serialize first: unserializable values (TypeError) or circular
        # references (ValueError) must not clobber the existing file.
        payload = json.dumps(config_data, ensure_ascii=False, indent=4)
    except (TypeError, ValueError):
        return False

    tmp_file = f"{config_file}.tmp"
    try:
        with open(tmp_file, 'w', encoding='utf-8') as f:
            f.write(payload)
        # Same directory as the target, so the rename stays on one filesystem.
        os.replace(tmp_file, config_file)
    except (OSError, ValueError):
        # ValueError here is e.g. UnicodeEncodeError for lone surrogates.
        # Best-effort cleanup of the partial temp file.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        return False
    return True
def test_llm_config(interface_format, api_key, base_url, model_name, temperature, max_tokens, timeout, log_func, handle_exception_func):
    """Check in the background whether the given LLM settings produce a reply.

    Starts a daemon thread and returns immediately (the function returns
    None). The thread builds an adapter with ``create_llm_adapter`` from the
    given settings and sends it a short probe prompt via ``invoke``.

    Args:
        interface_format, api_key, base_url, model_name, temperature,
        max_tokens, timeout: forwarded unchanged as keyword arguments to
            ``create_llm_adapter``.
        log_func: called with progress/result messages (strings).
        handle_exception_func: called with a short context message when any
            exception occurs. It runs inside the ``except`` block, so it
            presumably inspects the active exception — confirm against caller.
    """
    adapter_kwargs = dict(
        interface_format=interface_format,
        base_url=base_url,
        model_name=model_name,
        api_key=api_key,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
    )

    def _probe():
        try:
            log_func("开始测试LLM配置...")
            adapter = create_llm_adapter(**adapter_kwargs)
            reply = adapter.invoke("Please reply 'OK'")
            # Guard clause: an empty/falsy reply counts as a failed test.
            if not reply:
                log_func("❌ LLM配置测试失败:未获取到响应")
                return
            log_func("✅ LLM配置测试成功!")
            log_func(f"测试回复: {reply}")
        except Exception as exc:
            log_func(f"❌ LLM配置测试出错: {str(exc)}")
            handle_exception_func("测试LLM配置时出错")

    # Daemon thread so a hung request cannot block interpreter shutdown.
    worker = threading.Thread(target=_probe, daemon=True)
    worker.start()
def test_embedding_config(api_key, base_url, interface_format, model_name, log_func, handle_exception_func):
    """Check in the background whether the given Embedding settings work.

    Starts a daemon thread and returns immediately (the function returns
    None). The thread builds an adapter with ``create_embedding_adapter`` and
    embeds a short sample text via ``embed_query``. It then reports the
    outcome through ``log_func``:
    - on success, a success message and the length of the returned vector;
    - otherwise, a failure message when no vector came back.

    Args:
        api_key, base_url, interface_format, model_name: forwarded unchanged
            as keyword arguments to ``create_embedding_adapter``.
        log_func: called with progress/result messages (strings).
        handle_exception_func: called with a short context message when any
            exception occurs. It runs inside the ``except`` block, so it
            presumably inspects the active exception — confirm against caller.
    """
    def task():
        try:
            log_func("开始测试Embedding配置...")
            embedding_adapter = create_embedding_adapter(
                interface_format=interface_format,
                api_key=api_key,
                base_url=base_url,
                model_name=model_name
            )
            test_text = "测试文本"
            embeddings = embedding_adapter.embed_query(test_text)
            # Explicit None/length check rather than `embeddings and ...`.
            # The truthiness test was redundant for lists, and it raises
            # ValueError for array-like vectors with an ambiguous truth value
            # (such as numpy arrays). That would turn a successful embedding
            # into a reported error.
            if embeddings is not None and len(embeddings) > 0:
                log_func("✅ Embedding配置测试成功!")
                log_func(f"生成的向量维度: {len(embeddings)}")
            else:
                log_func("❌ Embedding配置测试失败:未获取到向量")
        except Exception as e:
            log_func(f"❌ Embedding配置测试出错: {str(e)}")
            handle_exception_func("测试Embedding配置时出错")

    # Daemon thread so a hung request cannot block interpreter shutdown.
    threading.Thread(target=task, daemon=True).start()