Skip to content

Commit 962fbfc

Browse files
authored
Merge pull request #19 from gitveg/main
feat: add search embeddings in Milvus
2 parents c601195 + d7a82fd commit 962fbfc

File tree

6 files changed

+227
-23
lines changed

6 files changed

+227
-23
lines changed

rfcs/assets/search_docs_false.png

246 KB
Loading

rfcs/assets/search_docs_true.png

293 KB
Loading

rfcs/notes/kjn-notes-2025.1.md

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# 1月工作记录
2+
3+
本月主要工作集中在编写、测试向量数据库的检索逻辑代码
4+
5+
## milvus检索代码
6+
7+
`milvus` 向量数据库提供了两种检索的方法,分别是 query 和 search 。
8+
9+
其中,search 方法主要用于执行近似最近邻搜索(Approximate Nearest Neighbors, ANN),即根据给定的查询向量找到与之最相似的向量。它的核心功能是基于**向量相似性**进行检索。
10+
11+
query 方法用于执行更广泛的基于条件的查询,主要用于基于条件的过滤,根据指定的条件表达式检索数据
12+
13+
在代码编写上选择使用更适合 RAG 系统的 search 方法。
14+
15+
检索逻辑代码放在了 milvus_client.py 下:
16+
17+
```python
18+
@get_time
19+
def search_docs(self, query_embedding: List[float] = None, filter_expr: str = None, doc_limit: int = 10):
20+
"""
21+
从 Milvus 集合中检索文档。
22+
23+
Args:
24+
query_embedding (List[float]): 查询向量,用于基于向量相似性检索。
25+
filter_expr (str): 过滤条件表达式,用于基于字段值的过滤。如"user_id == 'abc1234'"
26+
limit (int): 返回的文档数量上限,默认为 10。
27+
28+
Returns:
29+
List[dict]: 检索到的文档列表,每个文档是一个字典,包含字段值和向量。
30+
"""
31+
try:
32+
if not self.sess:
33+
raise MilvusFailed("Milvus collection is not loaded. Call load_collection_() first.")
34+
35+
# 构造查询参数
36+
search_params = {
37+
"metric_type": self.search_params["metric_type"],
38+
"params": self.search_params["params"]
39+
}
40+
41+
# 构造查询表达式
42+
expr = ""
43+
if filter_expr:
44+
expr = filter_expr
45+
46+
# 构造检索参数
47+
search_params.update({
48+
"data": [query_embedding] if query_embedding else None,
49+
"anns_field": "embedding", # 指定集合中存储向量的字段名称。Milvus 会在该字段上进行向量相似性检索。
50+
"param": {"metric_type": "L2", "params": {"nprobe": 128}}, # 检索的精度和性能
51+
"limit": doc_limit, # 指定返回的最相似文档的数量上限
52+
"expr": expr,
53+
"output_fields": self.output_fields
54+
})
55+
56+
# 执行检索
57+
results = self.sess.search(**search_params)
58+
59+
# 处理检索结果
60+
retrieved_docs = []
61+
for hits in results:
62+
for hit in hits:
63+
doc = {
64+
# "id": hit.id,
65+
# "distance": hit.distance,
66+
"user_id": hit.entity.get("user_id"),
67+
"kb_id": hit.entity.get("kb_id"),
68+
"file_id": hit.entity.get("file_id"),
69+
"headers": json.loads(hit.entity.get("headers")),
70+
"doc_id": hit.entity.get("doc_id"),
71+
"content": hit.entity.get("content"),
72+
"embedding": hit.entity.get("embedding")
73+
}
74+
retrieved_docs.append(doc)
75+
76+
return retrieved_docs
77+
78+
except Exception as e:
79+
print(f'[{cur_func_name()}] [search_docs] Failed to search documents: {traceback.format_exc()}')
80+
raise MilvusFailed(f"Failed to search documents: {str(e)}")
81+
```
82+
83+
## 测试milvus检索逻辑
84+
85+
利用已有的 embedding 文件夹下的 embedding_client.py(原名为 client.py )中的embedding处理代码,同时编写了 embed_user_input 方便测试。
86+
87+
同时在 milvus_client.py 的 main 函数中调用 search_docs 函数进行测试,测试结果如下。
88+
89+
不设置过滤条件正常检索:
90+
91+
![search_true](/rfcs/assets/search_docs_true.png)
92+
93+
设置过滤条件,检索结果为空:
94+
95+
![search_false](/rfcs/assets/search_docs_false.png)
96+
97+
98+
## 未来工作
99+
100+
后续继续实现 server 与 client 的交互处理,方便更好地测试用户的输入经过 embedding 后到 milvus 中进行检索的过程。
101+
102+
RAG 系统的 UI 界面逐步完善。

src/client/database/milvus/milvus_client.py

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from src.utils.general_utils import get_time, cur_func_name
2323
from src.configs.configs import MILVUS_HOST_LOCAL, MILVUS_PORT, VECTOR_SEARCH_TOP_K
2424

25+
from src.client.embedding.embedding_client import SBIEmbeddings, _process_query, embed_user_input
26+
2527

2628
class MilvusFailed(Exception):
2729
"""异常基类"""
@@ -107,6 +109,70 @@ def store_doc(self, doc: Document, embedding: List[float]):
107109
print(f'[{cur_func_name()}] [store_doc] Failed to store document: {traceback.format_exc()}')
108110
raise MilvusFailed(f"Failed to store document: {str(e)}")
109111

112+
@get_time
113+
def search_docs(self, query_embedding: List[float] = None, filter_expr: str = None, doc_limit: int = 10):
114+
"""
115+
从 Milvus 集合中检索文档。
116+
117+
Args:
118+
query_embedding (List[float]): 查询向量,用于基于向量相似性检索。
119+
filter_expr (str): 过滤条件表达式,用于基于字段值的过滤。如"user_id == 'abc1234'"
120+
limit (int): 返回的文档数量上限,默认为 10。
121+
122+
Returns:
123+
List[dict]: 检索到的文档列表,每个文档是一个字典,包含字段值和向量。
124+
"""
125+
try:
126+
if not self.sess:
127+
raise MilvusFailed("Milvus collection is not loaded. Call load_collection_() first.")
128+
129+
# 构造查询参数
130+
search_params = {
131+
"metric_type": self.search_params["metric_type"],
132+
"params": self.search_params["params"]
133+
}
134+
135+
# 构造查询表达式
136+
expr = ""
137+
if filter_expr:
138+
expr = filter_expr
139+
140+
# 构造检索参数
141+
search_params.update({
142+
"data": [query_embedding] if query_embedding else None,
143+
"anns_field": "embedding", # 指定集合中存储向量的字段名称。Milvus 会在该字段上进行向量相似性检索。
144+
"param": {"metric_type": "L2", "params": {"nprobe": 128}}, # 检索的精度和性能
145+
"limit": doc_limit, # 指定返回的最相似文档的数量上限
146+
"expr": expr,
147+
"output_fields": self.output_fields
148+
})
149+
150+
# 执行检索
151+
results = self.sess.search(**search_params)
152+
153+
# 处理检索结果
154+
retrieved_docs = []
155+
for hits in results:
156+
for hit in hits:
157+
doc = {
158+
# "id": hit.id,
159+
# "distance": hit.distance,
160+
"user_id": hit.entity.get("user_id"),
161+
"kb_id": hit.entity.get("kb_id"),
162+
"file_id": hit.entity.get("file_id"),
163+
"headers": json.loads(hit.entity.get("headers")),
164+
"doc_id": hit.entity.get("doc_id"),
165+
"content": hit.entity.get("content"),
166+
"embedding": hit.entity.get("embedding")
167+
}
168+
retrieved_docs.append(doc)
169+
170+
return retrieved_docs
171+
172+
except Exception as e:
173+
print(f'[{cur_func_name()}] [search_docs] Failed to search documents: {traceback.format_exc()}')
174+
raise MilvusFailed(f"Failed to search documents: {str(e)}")
175+
110176
@property
111177
def fields(self):
112178
fields = [
@@ -144,16 +210,17 @@ def main():
144210

145211
# 检索所有文档
146212
try:
147-
# 构造查询表达式(检索所有文档)
148-
query_expr = "" # 不设置过滤条件,检索所有文档
149-
150-
# 执行查询
151-
results = client.sess.query(
152-
expr=query_expr,
153-
output_fields=client.output_fields, # 指定返回的字段
154-
limit=1000
155-
)
213+
# # 构造查询表达式
214+
filter_expr = "123" # 设置过滤条件
156215

216+
# # 执行查询
217+
# results = client.sess.query(
218+
# expr=query_expr,
219+
# output_fields=client.output_fields, # 指定返回的字段
220+
# limit=1000
221+
# )
222+
query_expr = embed_user_input("荷塘月色")
223+
results = client.search_docs(query_expr, filter_expr, 1000)
157224
# 打印检索结果
158225
if not results:
159226
print(f"No documents found in collection {user_id}.")
@@ -165,7 +232,18 @@ def main():
165232
print(f" user_id: {result['user_id']}")
166233
print(f" kb_id: {result['kb_id']}")
167234
print(f" file_id: {result['file_id']}")
168-
print(f" headers: {json.loads(result['headers'])}") # 将 headers 从 JSON 字符串解析为字典
235+
# 检查 headers 的类型
236+
headers = result.get('headers')
237+
if isinstance(headers, dict):
238+
print(f" headers: {headers}")
239+
elif isinstance(headers, str):
240+
try:
241+
headers = json.loads(headers)
242+
print(f" headers: {headers}")
243+
except json.JSONDecodeError as e:
244+
print(f" headers: {headers} (无法解析为 JSON)")
245+
else:
246+
print(f" headers: {headers} (未知类型)")
169247
print(f" doc_id: {result['doc_id']}")
170248
print(f" content: {result['content']}")
171249
print(f" embedding: {result['embedding'][:5]}... (truncated)") # 只打印前 5 维向量
Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -180,23 +180,47 @@ async def performance_test():
180180
async_time = time.time() - start_time
181181
debug_logger.info(f"异步处理 {size} 个文本耗时: {async_time:.2f}秒")
182182

183+
def embed_user_input(user_input: str):
184+
"""测试用户输入的文本嵌入"""
185+
embedder = SBIEmbeddings()
186+
187+
# 对用户输入的文本进行预处理
188+
processed_input = _process_query(user_input)
189+
190+
debug_logger.info("\n测试用户输入的嵌入:")
191+
debug_logger.info(f"用户输入: {user_input}")
192+
debug_logger.info(f"预处理后的输入: {processed_input}")
193+
194+
try:
195+
# 使用同步方法获取嵌入向量
196+
embedding = embedder.embed_query(processed_input)
197+
debug_logger.info(f"嵌入向量维度: {len(embedding)}")
198+
debug_logger.info(f"嵌入向量: {embedding}")
199+
except Exception as e:
200+
debug_logger.error(f"嵌入过程中发生错误: {str(e)}")
201+
202+
return embedding
203+
183204

184205
async def main():
185206
"""主测试函数"""
186207
debug_logger.info(f"开始embedding客户端测试...")
187208

188-
# 测试异步方法
189-
await test_async_methods()
190-
191-
# # 测试同步方法
192-
# test_sync_methods()
193-
194-
# # 测试错误处理
195-
# test_error_handling()
196-
197-
# # 执行性能测试
198-
# await performance_test()
199-
209+
try:
210+
# 测试异步方法
211+
await test_async_methods()
212+
213+
# 测试同步方法
214+
test_sync_methods()
215+
216+
# 测试错误处理
217+
test_error_handling()
218+
219+
# 执行性能测试
220+
await performance_test()
221+
except Exception as e:
222+
debug_logger.error(f"测试过程中发生错误: {str(e)}")
223+
200224
debug_logger.info("embedding客户端测试完成")
201225

202226

src/server/api_server/sanic_api_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ async def upload_files(req: request):
198198
# 返回给前端的数据
199199
data.append({"file_id": file_id, "file_name": file_name, "status": "green",
200200
"bytes": len(local_file.file_content), "timestamp": timestamp, "estimated_chars": chars})
201-
# qanything 1.x版本处理方式,2.0以后的版本都是起另外一个服务轮训文件状态,之后添加到向量数据库中
201+
# qanything 1.x版本处理方式,2.0以后的版本都是起另外一个服务轮询文件状态,之后添加到向量数据库中
202202
# 后面做优化在像他们那样做,这样文件上传流程会快不少
203203
# asyncio.create_task(local_doc_qa.insert_files_to_milvus(user_id, kb_id, local_files))
204204
if failed_files:

0 commit comments

Comments
 (0)