Skip to content

Commit 14d95ce

Browse files
authored
feat: 知识库更新,支持Milvus、Chroma、Lightrag类型的知识库 (#223)
1 parent f463f80 commit 14d95ce

17 files changed

+2981
-593
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ readme = "README.md"
66
requires-python = ">=3.11"
77
dependencies = [
88
"asyncpg>=0.30.0",
9+
"chromadb>=1.0.15",
910
"colorlog>=6.9.0",
1011
"dashscope>=1.23.2",
1112
"docx2txt>=0.9",

server/routers/data_router.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,194 @@ async def get_databases(current_user: User = Depends(get_admin_user)):
2020
return {"message": f"获取数据库列表失败 {e}", "databases": []}
2121
return database
2222

23+
@data.get("/kb-types")
async def get_knowledge_base_types(current_user: User = Depends(get_admin_user)):
    """Return the knowledge base types supported by the backend.

    The payload comes straight from ``knowledge_base.get_supported_kb_types()``
    and is passed through unchanged under the ``"kb_types"`` key.

    Args:
        current_user: User resolved through the ``get_admin_user`` dependency;
            not read by the handler body itself.

    Returns:
        ``{"kb_types": <types>, "message": "success"}`` on success. If the
        lookup raises, the error is logged with its traceback and
        ``{"message": <error text>, "kb_types": {}}`` is returned instead of
        propagating the exception.
    """
    try:
        supported_types = knowledge_base.get_supported_kb_types()
    except Exception as e:
        # Best-effort endpoint: report the failure in the body rather than raising.
        logger.error(f"获取知识库类型失败 {e}, {traceback.format_exc()}")
        return {"message": f"获取知识库类型失败 {e}", "kb_types": {}}

    return {"kb_types": supported_types, "message": "success"}
32+
33+
@data.get("/stats")
async def get_knowledge_base_statistics(current_user: User = Depends(get_admin_user)):
    """Return knowledge base statistics.

    The statistics object is whatever ``knowledge_base.get_statistics()``
    produces; it is forwarded unchanged under the ``"stats"`` key.

    Args:
        current_user: User resolved through the ``get_admin_user`` dependency;
            not read by the handler body itself.

    Returns:
        ``{"stats": <statistics>, "message": "success"}`` on success. If the
        lookup raises, the error is logged with its traceback and
        ``{"message": <error text>, "stats": {}}`` is returned instead of
        propagating the exception.
    """
    try:
        statistics = knowledge_base.get_statistics()
    except Exception as e:
        # Best-effort endpoint: report the failure in the body rather than raising.
        logger.error(f"获取知识库统计失败 {e}, {traceback.format_exc()}")
        return {"message": f"获取知识库统计失败 {e}", "stats": {}}

    return {"stats": statistics, "message": "success"}
42+
43+
def _top_k_option():
    """Build the ``top_k`` numeric option that every knowledge base type exposes."""
    return {
        "key": "top_k",
        "label": "TopK",
        "type": "number",
        "default": 10,
        "min": 1,
        "max": 100,
        "description": "返回的最大结果数量",
    }


def _vector_store_options():
    """Build the options shared by the Chroma and Milvus knowledge base types."""
    return [
        _top_k_option(),
        {
            "key": "similarity_threshold",
            "label": "相似度阈值",
            "type": "number",
            "default": 0.0,
            "min": 0.0,
            "max": 1.0,
            "step": 0.1,
            "description": "过滤相似度低于此值的结果",
        },
        {
            "key": "include_distances",
            "label": "显示相似度",
            "type": "boolean",
            "default": True,
            "description": "在结果中显示相似度分数",
        },
    ]


def _lightrag_options():
    """Build the query options specific to LightRAG knowledge bases."""
    return [
        {
            "key": "mode",
            "label": "检索模式",
            "type": "select",
            "default": "mix",
            "options": [
                {"value": "local", "label": "Local", "description": "上下文相关信息"},
                {"value": "global", "label": "Global", "description": "全局知识"},
                {"value": "hybrid", "label": "Hybrid", "description": "本地和全局混合"},
                {"value": "naive", "label": "Naive", "description": "基本搜索"},
                {"value": "mix", "label": "Mix", "description": "知识图谱和向量检索混合"},
            ],
        },
        {
            "key": "only_need_context",
            "label": "只使用上下文",
            "type": "boolean",
            "default": True,
            "description": "只返回上下文,不生成回答",
        },
        {
            "key": "only_need_prompt",
            "label": "只使用提示",
            "type": "boolean",
            "default": False,
            "description": "只返回提示,不进行检索",
        },
        _top_k_option(),
    ]


def _milvus_metric_option():
    """Build the Milvus-only ``metric_type`` select option."""
    return {
        "key": "metric_type",
        "label": "距离度量类型",
        "type": "select",
        "default": "COSINE",
        "options": [
            {"value": "COSINE", "label": "余弦相似度", "description": "适合文本语义相似度"},
            {"value": "L2", "label": "欧几里得距离", "description": "适合数值型数据"},
            {"value": "IP", "label": "内积", "description": "适合标准化向量"},
        ],
        "description": "向量相似度计算方法",
    }


@data.get("/query-params/{db_id}")
async def get_knowledge_base_query_params(db_id: str, current_user: User = Depends(get_admin_user)):
    """Return the query parameter descriptors for one knowledge base.

    The database record is looked up with ``knowledge_base.get_database_info``
    and its ``"kb_type"`` field (defaulting to ``"lightrag"`` when absent)
    selects which list of option descriptors is returned. Each descriptor is a
    dict with ``key``, ``label``, ``type``, ``default`` and, depending on the
    type, ``min``/``max``/``step``/``options``/``description``.

    Args:
        db_id: Identifier of the knowledge base, taken from the URL path.
        current_user: User resolved through the ``get_admin_user`` dependency;
            not read by the handler body itself.

    Returns:
        ``{"params": {"type": <kb type>, "options": [...]}, "message": "success"}``.
        An unrecognised ``kb_type`` yields ``"type": "unknown"`` with only the
        ``top_k`` option. On an unexpected error the failure is logged and
        ``{"message": <error text>, "params": {}}`` is returned.

    Raises:
        HTTPException: 404 when ``get_database_info`` returns a falsy value.
    """
    try:
        # Fetch the database record.
        db_info = knowledge_base.get_database_info(db_id)
        if not db_info:
            raise HTTPException(status_code=404, detail="Database not found")

        kb_type = db_info.get("kb_type", "lightrag")

        # Pick the option descriptors that match the knowledge base type.
        if kb_type == "lightrag":
            params = {"type": "lightrag", "options": _lightrag_options()}
        elif kb_type == "chroma":
            params = {"type": "chroma", "options": _vector_store_options()}
        elif kb_type == "milvus":
            params = {
                "type": "milvus",
                "options": [*_vector_store_options(), _milvus_metric_option()],
            }
        else:
            # Unknown type: fall back to the basic parameters only.
            params = {"type": "unknown", "options": [_top_k_option()]}

        return {"params": params, "message": "success"}

    except HTTPException:
        # Let the 404 reach the client; the blanket handler below would
        # otherwise turn it into a 200 response carrying an error message.
        raise
    except Exception as e:
        logger.error(f"获取知识库查询参数失败 {e}, {traceback.format_exc()}")
        return {"message": f"获取知识库查询参数失败 {e}", "params": {}}
195+
23196
@data.post("/")
24197
async def create_database(
25198
database_name: str = Body(...),
26199
description: str = Body(...),
27200
embed_model_name: str = Body(...),
201+
kb_type: str = Body("lightrag"), # 新增:知识库类型参数,默认为lightrag
28202
current_user: User = Depends(get_admin_user)
29203
):
30-
logger.debug(f"Create database {database_name}")
204+
logger.debug(f"Create database {database_name} with kb_type {kb_type}")
31205
try:
32206
embed_info = config.embed_model_names[embed_model_name]
33207
database_info = knowledge_base.create_database(
34208
database_name,
35209
description,
210+
kb_type=kb_type, # 传递知识库类型
36211
embed_info=embed_info
37212
)
38213
except Exception as e:

server/routers/graph_router.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,17 @@ async def get_subgraph(
3232
try:
3333
logger.info(f"获取子图数据 - db_id: {db_id}, node_label: {node_label}, max_depth: {max_depth}, max_nodes: {max_nodes}")
3434

35+
# 检查是否是 LightRAG 数据库
36+
if not knowledge_base.is_lightrag_database(db_id):
37+
raise HTTPException(
38+
status_code=400,
39+
detail=f"数据库 {db_id} 不是 LightRAG 类型,图谱功能仅支持 LightRAG 知识库"
40+
)
41+
3542
# 获取 LightRAG 实例
3643
rag_instance = await knowledge_base._get_lightrag_instance(db_id)
3744
if not rag_instance:
38-
raise HTTPException(status_code=404, detail=f"数据库 {db_id} 不存在")
45+
raise HTTPException(status_code=404, detail=f"LightRAG 数据库 {db_id} 不存在或无法访问")
3946

4047
# 使用 LightRAG 的原生 get_knowledge_graph 方法
4148
knowledge_graph = await rag_instance.get_knowledge_graph(
@@ -78,6 +85,9 @@ async def get_subgraph(
7885
logger.info(f"成功获取子图 - 节点数: {len(nodes)}, 边数: {len(edges)}")
7986
return result
8087

88+
except HTTPException:
89+
# 重新抛出 HTTP 异常
90+
raise
8191
except Exception as e:
8292
logger.error(f"获取子图数据失败: {e}")
8393
logger.error(f"Traceback: {traceback.format_exc()}")
@@ -101,10 +111,17 @@ async def get_graph_labels(
101111
try:
102112
logger.info(f"获取图谱标签 - db_id: {db_id}")
103113

114+
# 检查是否是 LightRAG 数据库
115+
if not knowledge_base.is_lightrag_database(db_id):
116+
raise HTTPException(
117+
status_code=400,
118+
detail=f"数据库 {db_id} 不是 LightRAG 类型,图谱功能仅支持 LightRAG 知识库"
119+
)
120+
104121
# 获取 LightRAG 实例
105122
rag_instance = await knowledge_base._get_lightrag_instance(db_id)
106123
if not rag_instance:
107-
raise HTTPException(status_code=404, detail=f"数据库 {db_id} 不存在")
124+
raise HTTPException(status_code=404, detail=f"LightRAG 数据库 {db_id} 不存在或无法访问")
108125

109126
# 使用 LightRAG 的原生方法获取所有标签
110127
labels = await rag_instance.get_graph_labels()
@@ -116,6 +133,9 @@ async def get_graph_labels(
116133
}
117134
}
118135

136+
except HTTPException:
137+
# 重新抛出 HTTP 异常
138+
raise
119139
except Exception as e:
120140
logger.error(f"获取图谱标签失败: {e}")
121141
logger.error(f"Traceback: {traceback.format_exc()}")
@@ -130,18 +150,20 @@ async def get_available_databases(
130150
获取所有可用的 LightRAG 数据库
131151
132152
Returns:
133-
可用的数据库列表
153+
可用的 LightRAG 数据库列表
134154
"""
135155
try:
136-
databases = knowledge_base.get_databases()
156+
lightrag_databases = knowledge_base.get_lightrag_databases()
137157
return {
138158
"success": True,
139-
"data": databases
159+
"data": {
160+
"databases": lightrag_databases
161+
}
140162
}
141163

142164
except Exception as e:
143-
logger.error(f"获取数据库列表失败: {e}")
144-
raise HTTPException(status_code=500, detail=f"获取数据库列表失败: {str(e)}")
165+
logger.error(f"获取 LightRAG 数据库列表失败: {e}")
166+
raise HTTPException(status_code=500, detail=f"获取 LightRAG 数据库列表失败: {str(e)}")
145167

146168

147169
# 保留原有的直接数据库查询方法作为备用(如果需要的话)
@@ -215,10 +237,17 @@ async def get_graph_stats(
215237
try:
216238
logger.info(f"获取图谱统计信息 - db_id: {db_id}")
217239

240+
# 检查是否是 LightRAG 数据库
241+
if not knowledge_base.is_lightrag_database(db_id):
242+
raise HTTPException(
243+
status_code=400,
244+
detail=f"数据库 {db_id} 不是 LightRAG 类型,图谱功能仅支持 LightRAG 知识库"
245+
)
246+
218247
# 获取 LightRAG 实例
219248
rag_instance = await knowledge_base._get_lightrag_instance(db_id)
220249
if not rag_instance:
221-
raise HTTPException(status_code=404, detail=f"数据库 {db_id} 不存在")
250+
raise HTTPException(status_code=404, detail=f"LightRAG 数据库 {db_id} 不存在或无法访问")
222251

223252
# 通过获取全图来统计节点和边的数量
224253
knowledge_graph = await rag_instance.get_knowledge_graph(
@@ -248,6 +277,9 @@ async def get_graph_stats(
248277
}
249278
}
250279

280+
except HTTPException:
281+
# 重新抛出 HTTP 异常
282+
raise
251283
except Exception as e:
252284
logger.error(f"获取图谱统计信息失败: {e}")
253285
logger.error(f"Traceback: {traceback.format_exc()}")

src/agents/tools_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def get_all_tools():
3232
)
3333

3434
# 创建异步工具,确保正确处理异步检索器
35-
async def async_retriever_wrapper(query_text: str, db_id=db_Id):
35+
async def async_retriever_wrapper(query_text: str, db_id=db_Id, retriever_info=retrieve_info):
3636
"""异步检索器包装函数"""
37-
retriever = retrieve_info["retriever"]
37+
retriever = retriever_info["retriever"]
3838
try:
3939
if asyncio.iscoroutinefunction(retriever):
4040
result = await retriever(query_text)

src/core/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
from .history import HistoryManager
2-
from .lightrag_based_kb import LightRagBasedKB
32
from .graphbase import GraphDatabase

0 commit comments

Comments
 (0)