Skip to content

Commit 4eb2c49

Browse files
committed
fix(lightrag): 修复查询结果节点数量未限制的问题
1 parent cd34ea5 commit 4eb2c49

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/knowledge/adapters/lightrag.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def query_nodes(self, keyword: str, **kwargs) -> dict[str, Any]:
4242
try:
4343
with self._db.driver.session() as session:
4444
result = session.run(query, keyword=keyword, kb_id=kb_id, limit=limit)
45-
return self._process_query_result(result)
45+
return self._process_query_result(result, limit=limit)
4646
except Exception as e:
4747
logger.error(f"Neo4j query failed: {e}")
4848
return {"nodes": [], "edges": []}
@@ -280,21 +280,28 @@ def _build_subgraph_query(self, limit: int, kb_id: str = None) -> str:
280280

281281
return query
282282

283-
def _process_query_result(self, result) -> dict[str, list]:
284-
"""处理查询结果"""
283+
def _process_query_result(self, result, limit: int = None) -> dict[str, list]:
284+
"""处理查询结果,并限制节点数量不超过 limit"""
285285
nodes = []
286286
edges = []
287287
node_ids = set()
288288
edge_ids = set()
289289

290290
for record in result:
291+
# 检查是否已达到节点限制
292+
if limit is not None and len(node_ids) >= limit:
293+
break
294+
291295
for key in record.keys():
292296
val = record[key]
293297
if val is None:
294298
continue
295299

296300
if hasattr(val, "element_id") and hasattr(val, "labels"): # Node
297301
if val.element_id not in node_ids:
302+
# 再次检查限制
303+
if limit is not None and len(node_ids) >= limit:
304+
break
298305
nodes.append(self.normalize_node(val))
299306
node_ids.add(val.element_id)
300307
elif hasattr(val, "element_id") and hasattr(val, "start_node"): # Relationship
@@ -305,6 +312,8 @@ def _process_query_result(self, result) -> dict[str, list]:
305312
for item in val:
306313
if hasattr(item, "element_id") and hasattr(item, "labels"):
307314
if item.element_id not in node_ids:
315+
if limit is not None and len(node_ids) >= limit:
316+
break
308317
nodes.append(self.normalize_node(item))
309318
node_ids.add(item.element_id)
310319

0 commit comments

Comments
 (0)