@@ -42,7 +42,7 @@ async def query_nodes(self, keyword: str, **kwargs) -> dict[str, Any]:
4242 try :
4343 with self ._db .driver .session () as session :
4444 result = session .run (query , keyword = keyword , kb_id = kb_id , limit = limit )
45- return self ._process_query_result (result )
45+ return self ._process_query_result (result , limit = limit )
4646 except Exception as e :
4747 logger .error (f"Neo4j query failed: { e } " )
4848 return {"nodes" : [], "edges" : []}
@@ -280,21 +280,28 @@ def _build_subgraph_query(self, limit: int, kb_id: str = None) -> str:
280280
281281 return query
282282
283- def _process_query_result (self , result ) -> dict [str , list ]:
284- """处理查询结果"""
283+ def _process_query_result (self , result , limit : int = None ) -> dict [str , list ]:
284+ """处理查询结果,并限制节点数量不超过 limit """
285285 nodes = []
286286 edges = []
287287 node_ids = set ()
288288 edge_ids = set ()
289289
290290 for record in result :
291+ # 检查是否已达到节点限制
292+ if limit is not None and len (node_ids ) >= limit :
293+ break
294+
291295 for key in record .keys ():
292296 val = record [key ]
293297 if val is None :
294298 continue
295299
296300 if hasattr (val , "element_id" ) and hasattr (val , "labels" ): # Node
297301 if val .element_id not in node_ids :
302+ # 再次检查限制
303+ if limit is not None and len (node_ids ) >= limit :
304+ break
298305 nodes .append (self .normalize_node (val ))
299306 node_ids .add (val .element_id )
300307 elif hasattr (val , "element_id" ) and hasattr (val , "start_node" ): # Relationship
@@ -305,6 +312,8 @@ def _process_query_result(self, result) -> dict[str, list]:
305312 for item in val :
306313 if hasattr (item , "element_id" ) and hasattr (item , "labels" ):
307314 if item .element_id not in node_ids :
315+ if limit is not None and len (node_ids ) >= limit :
316+ break
308317 nodes .append (self .normalize_node (item ))
309318 node_ids .add (item .element_id )
310319
0 commit comments