Skip to content

Commit 610b60e

Browse files
committed
feat(knowledge): 智能体查询知识库时,支持基于文件名的模糊过滤功能,不支持 LightRAG
实现知识库检索时可按文件名进行模糊匹配过滤;在 Milvus 知识库中支持文件名的 like 表达式过滤;前端展示添加文件名显示;添加相关测试用例验证过滤功能。
1 parent 6d9f3ad commit 610b60e

File tree

6 files changed

+179
-7
lines changed

6 files changed

+179
-7
lines changed

src/agents/common/tools.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ class KnowledgeRetrieverModel(BaseModel):
122122
)
123123

124124

125+
class CommonKnowledgeRetriever(KnowledgeRetrieverModel):
    """Tool-argument schema that adds an optional file-name filter.

    Extends ``KnowledgeRetrieverModel`` with ``file_name`` so a search can be
    restricted to matching files.
    """

    # The filter is optional: the retriever wrapper declares
    # ``file_name: str | None = None`` and only forwards it when set, so the
    # schema must not force callers to supply a value on every call.
    file_name: str | None = Field(
        default=None,
        description="限定文件名称,当操作类型为 'search' 时,可以指定文件名称,支持模糊匹配",
    )
128+
129+
125130
def get_kb_based_tools(db_names: list[str] | None = None) -> list:
126131
"""获取所有知识库基于的工具"""
127132
# 获取所有知识库
@@ -132,7 +137,9 @@ def get_kb_based_tools(db_names: list[str] | None = None) -> list:
132137
def _create_retriever_wrapper(db_id: str, retriever_info: dict[str, Any]):
133138
"""创建检索器包装函数的工厂函数,避免闭包变量捕获问题"""
134139

135-
async def async_retriever_wrapper(query_text: str, operation: str = "search") -> Any:
140+
async def async_retriever_wrapper(
141+
query_text: str, operation: str = "search", file_name: str | None = None
142+
) -> Any:
136143
"""异步检索器包装函数,支持检索和获取思维导图"""
137144

138145
# 获取思维导图
@@ -173,10 +180,14 @@ def mindmap_to_text(node, level=0):
173180
retriever = retriever_info["retriever"]
174181
try:
175182
logger.debug(f"Retrieving from database {db_id} with query: {query_text}")
183+
kwargs = {}
184+
if file_name:
185+
kwargs["file_name"] = file_name
186+
176187
if asyncio.iscoroutinefunction(retriever):
177-
result = await retriever(query_text)
188+
result = await retriever(query_text, **kwargs)
178189
else:
179-
result = retriever(query_text)
190+
result = retriever(query_text, **kwargs)
180191
logger.debug(f"Retrieved {len(result) if isinstance(result, list) else 'N/A'} results from {db_id}")
181192
return result
182193
except Exception as e:
@@ -207,12 +218,16 @@ def mindmap_to_text(node, level=0):
207218

208219
safename = retrieve_info["name"].replace(" ", "_")[:20]
209220

221+
args_schema = KnowledgeRetrieverModel
222+
if retrieve_info["metadata"]["kb_type"] in ["milvus"]:
223+
args_schema = CommonKnowledgeRetriever
224+
210225
# 使用 StructuredTool.from_function 创建异步工具
211226
tool = StructuredTool.from_function(
212227
coroutine=retriever_wrapper,
213228
name=safename,
214229
description=description,
215-
args_schema=KnowledgeRetrieverModel,
230+
args_schema=args_schema,
216231
metadata=retrieve_info["metadata"] | {"tag": ["knowledgebase"]},
217232
)
218233

src/knowledge/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,8 @@ def get_retrievers(self) -> dict[str, dict]:
549549
for db_id, meta in self.databases_meta.items():
550550

551551
def make_retriever(db_id):
552-
async def retriever(query_text):
553-
return await self.aquery(query_text, db_id)
552+
async def retriever(query_text, **kwargs):
553+
return await self.aquery(query_text, db_id, **kwargs)
554554

555555
return retriever
556556

src/knowledge/implementations/milvus.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,11 +458,26 @@ async def aquery(self, query_text: str, db_id: str, **kwargs) -> list[dict]:
458458
query_embedding = embedding_function([query_text])
459459

460460
search_params = {"metric_type": metric_type, "params": {"nprobe": 10}}
461+
462+
# 构建过滤表达式
463+
expr = None
464+
if file_name := kwargs.get("file_name"):
465+
# 使用 like 支持模糊匹配
466+
# 注意:需要转义双引号以防止注入
467+
safe_file_name = file_name.replace('"', '\\"')
468+
# 如果没有提供通配符,默认前后添加 %
469+
if "%" not in safe_file_name:
470+
expr = f'source like "%{safe_file_name}%"'
471+
else:
472+
expr = f'source like "{safe_file_name}"'
473+
logger.debug(f"Using filter expression: {expr}")
474+
461475
results = collection.search(
462476
data=query_embedding,
463477
anns_field="embedding",
464478
param=search_params,
465479
limit=recall_top_k,
480+
expr=expr,
466481
output_fields=["content", "source", "chunk_id", "file_id", "chunk_index"],
467482
)
468483

test/test_milvus_filter.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import asyncio
2+
import os
3+
import shutil
4+
from unittest.mock import MagicMock, patch
5+
6+
from src.knowledge import knowledge_base
7+
from src.utils import logger
8+
9+
# Mock Embedding Model
10+
class MockEmbeddingModel:
    """Deterministic stand-in for an embedding model.

    Every input text is mapped to the same fixed vector, so no real model is
    needed.

    Args:
        dimension: Length of each returned vector. Defaults to 4, which
            yields ``[0.1, 0.2, 0.3, 0.4]``.
    """

    def __init__(self, dimension=4):
        self.dimension = dimension

    def _vector(self):
        """Build a fresh dummy vector of ``self.dimension`` floats."""
        # (i + 1) / 10 gives 0.1, 0.2, 0.3, 0.4 for the default dimension.
        return [(i + 1) / 10 for i in range(self.dimension)]

    async def abatch_encode(self, texts, batch_size=None):
        """Async variant of :meth:`batch_encode`; ``batch_size`` is ignored."""
        return self.batch_encode(texts, batch_size=batch_size)

    def batch_encode(self, texts, batch_size=None):
        """Return one dummy vector per input text; ``batch_size`` is ignored."""
        return [self._vector() for _ in texts]
17+
18+
# Test function
19+
async def test_milvus_filter():
    """Integration check for file-name filtering on a Milvus knowledge base.

    Creates a Milvus-backed knowledge base containing two small text files,
    then checks that ``knowledge_base.aquery`` honours the ``file_name``
    keyword for a plain substring ("file_A") and for an explicit ``%``
    wildcard pattern ("%B.txt"). The database and the files are removed
    afterwards.

    Raises:
        AssertionError: If the database cannot be created, a filtered query
            returns no results, or a result comes from an unexpected file.
    """
    logger.info("Starting Milvus Filter Test")

    # Check if Milvus is available (pymilvus installed and connection works)
    try:
        from pymilvus import connections

        # Assuming Milvus is running at default location
        connections.connect(alias="default", uri=os.getenv("MILVUS_URI", "http://localhost:19530"))
        logger.info("Connected to Milvus")
    except Exception as e:
        logger.warning(f"Milvus not available or connection failed: {e}")
        # Best effort: keep going; the steps below fail loudly if Milvus is
        # really unreachable.

    db_id = "test_milvus_filter_db"
    file1 = "test_file_A.txt"
    file2 = "test_file_B.txt"

    # Patch embedding model
    with patch("src.models.embed.select_embedding_model", return_value=MockEmbeddingModel()):
        try:
            # Cleanup if exists
            if db_id in knowledge_base.global_databases_meta:
                await knowledge_base.delete_database(db_id)

            # Create DB
            logger.info("Creating database...")
            # explicitly set dimension to 4 to match mock
            await knowledge_base.create_database(
                database_name="Test Milvus Filter",
                description="Test DB",
                kb_type="milvus",
                embed_info={"name": "mock-embedding", "dimension": 4, "model_id": "mock"},
            )

            # Get actual db_id
            target_db = next(
                (db for db in knowledge_base.get_databases()["databases"] if db["name"] == "Test Milvus Filter"),
                None,
            )
            if not target_db:
                logger.error("Failed to create DB")
                # A silent return would let the test pass without testing anything.
                raise AssertionError("Failed to create DB")

            db_id = target_db["db_id"]
            logger.info(f"DB created with ID: {db_id}")

            # Create dummy files
            with open(file1, "w", encoding="utf-8") as f:
                f.write("Apple content.")
            with open(file2, "w", encoding="utf-8") as f:
                f.write("Banana content.")

            # Add content
            logger.info("Adding content...")
            await knowledge_base.add_content(db_id, [os.path.abspath(file1), os.path.abspath(file2)])

            # Wait for data to be visible
            logger.info("Waiting for data to be visible...")
            await asyncio.sleep(2)

            # Query without filter (logged for diagnostics only)
            logger.info("Querying without filter...")
            results = await knowledge_base.aquery("content", db_id)
            logger.info(f"No filter results: {len(results)}")

            sources = [r["metadata"]["source"] for r in results]
            logger.info(f"Sources: {sources}")

            # Query with filter A (Partial Match)
            logger.info("Querying with filter A (file_A)...")
            results_a = await knowledge_base.aquery("content", db_id, file_name="file_A")
            logger.info(f"Filter A results: {len(results_a)}")

            if not results_a:
                logger.error("FAIL: Filter A returned 0 results")
                # An empty result would make the per-result loop below vacuous.
                raise AssertionError("Filter A returned 0 results")

            for r in results_a:
                source = r["metadata"]["source"]
                logger.info(f" - {source}")
                if "test_file_A.txt" not in source:
                    logger.error(f"FAIL: Expected test_file_A.txt, got {source}")
                    raise AssertionError("Filter A failed")

            # Query with wildcard filter
            logger.info("Querying with wildcard filter (%B.txt)...")
            results_b = await knowledge_base.aquery("content", db_id, file_name="%B.txt")
            logger.info(f"Filter B results: {len(results_b)}")

            if not results_b:
                logger.error("FAIL: Wildcard filter returned 0 results")
                raise AssertionError("Wildcard filter returned 0 results")

            for r in results_b:
                source = r["metadata"]["source"]
                logger.info(f" - {source}")
                if "test_file_B.txt" not in source:
                    logger.error(f"FAIL: Expected test_file_B.txt, got {source}")
                    raise AssertionError("Filter B failed")

            # Both filtered queries returned only the expected file.
            logger.info("Test passed!")

        except Exception as e:
            logger.error(f"Test failed with exception: {e}")
            raise
        finally:
            # Cleanup
            logger.info("Cleaning up...")
            try:
                await knowledge_base.delete_database(db_id)
            except Exception as cleanup_error:
                # Cleanup stays best-effort, but the failure is no longer hidden.
                logger.warning(f"Failed to delete database {db_id} during cleanup: {cleanup_error}")
            for path in (file1, file2):
                if os.path.exists(path):
                    os.remove(path)
135+
136+
if __name__ == "__main__":
    # Allow running this check directly as a script.
    asyncio.run(test_milvus_filter())

web/src/components/ToolCallingResult/BaseToolCall.vue

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,6 @@ const formatResultData = (data) => {
307307
text-overflow: ellipsis;
308308
white-space: nowrap;
309309
min-width: 0;
310-
flex: 1;
311310
}
312311
313312
:deep(.tag) {

web/src/components/ToolCallingResult/tools/KnowledgeBaseTool.vue

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
<span class="note">{{ operationLabel }}</span>
66
<span class="separator" v-if="queryText">|</span>
77
<span class="description">{{ queryText }}</span>
8+
<span class="separator" v-if="fileName">|</span>
9+
<span class="description" v-if="fileName">文件: {{ fileName }}</span>
810
</div>
911
</template>
1012
<template #result="{ resultContent }">
@@ -169,6 +171,10 @@ const queryText = computed(() => {
169171
return args.value.query_text || '';
170172
});
171173
174+
// Optional `file_name` read from `args`; empty string when absent.
const fileName = computed(() => args.value.file_name || '');
177+
172178
const parseData = (content) => {
173179
if (typeof content === 'string') {
174180
try {

0 commit comments

Comments
 (0)