@@ -64,6 +64,22 @@ def get_sample_nodes(self, kgdb_name="neo4j", num=50):
6464 assert self .driver is not None , "Database is not connected"
6565 self .use_database (kgdb_name )
6666
67+ def _process_record_props (record ):
68+ """处理记录中的属性:扁平化 properties 并移除 embedding"""
69+ if record is None :
70+ return None
71+
72+ # 复制一份以避免修改原字典
73+ data = dict (record )
74+ props = data .pop ("properties" , {}) or {}
75+
76+ # 移除 embedding
77+ if "embedding" in props :
78+ del props ["embedding" ]
79+
80+ # 合并属性(优先保留原字典中的 id, name, type 等核心字段)
81+ return {** props , ** data }
82+
6783 def query (tx , num ):
6884 """Note: 使用连通性查询获取集中的节点子图"""
6985 # 首先尝试获取一个连通的子图
@@ -104,17 +120,18 @@ def query(tx, num):
104120 OPTIONAL MATCH (n)-[rel]-(m)
105121 WHERE m IN final_nodes AND elementId(n) < elementId(m)
106122 RETURN
107- {id: elementId(n), name: n.name} AS h,
123+ {id: elementId(n), name: n.name, properties: properties(n) } AS h,
108124 CASE WHEN rel IS NOT NULL THEN
109125 {
110126 id: elementId(rel),
111127 type: rel.type,
112128 source_id: elementId(startNode(rel)),
113- target_id: elementId(endNode(rel))
129+ target_id: elementId(endNode(rel)),
130+ properties: properties(rel)
114131 }
115132 ELSE null END AS r,
116133 CASE WHEN m IS NOT NULL THEN
117- {id: elementId(m), name: m.name}
134+ {id: elementId(m), name: m.name, properties: properties(m) }
118135 ELSE null END AS t
119136 """
120137
@@ -124,7 +141,7 @@ def query(tx, num):
124141 node_ids = set ()
125142
126143 for item in results :
127- h_node = item ["h" ]
144+ h_node = _process_record_props ( item ["h" ])
128145
129146 # 始终添加头节点
130147 if h_node ["id" ] not in node_ids :
@@ -133,14 +150,15 @@ def query(tx, num):
133150
134151 # 只有当边和尾节点都存在时才处理
135152 if item ["r" ] is not None and item ["t" ] is not None :
136- t_node = item ["t" ]
153+ t_node = _process_record_props (item ["t" ])
154+ r_edge = _process_record_props (item ["r" ])
137155
138156 # 避免重复添加尾节点
139157 if t_node ["id" ] not in node_ids :
140158 formatted_results ["nodes" ].append (t_node )
141159 node_ids .add (t_node ["id" ])
142160
143- formatted_results ["edges" ].append (item [ "r" ] )
161+ formatted_results ["edges" ].append (r_edge )
144162
145163 # 如果连通查询返回的节点数不足,补充更多节点
146164 if len (formatted_results ["nodes" ]) < num :
@@ -150,14 +168,14 @@ def query(tx, num):
150168 supplement_query = """
151169 MATCH (n:Entity)
152170 WHERE NOT elementId(n) IN $existing_ids
153- RETURN {id: elementId(n), name: n.name} AS node
171+ RETURN {id: elementId(n), name: n.name, properties: properties(n) } AS node
154172 LIMIT $count
155173 """
156174
157175 supplement_results = tx .run (supplement_query , existing_ids = list (node_ids ), count = remaining_count )
158176
159177 for item in supplement_results :
160- node = item ["node" ]
178+ node = _process_record_props ( item ["node" ])
161179 formatted_results ["nodes" ].append (node )
162180 node_ids .add (node ["id" ])
163181
@@ -170,23 +188,25 @@ def query(tx, num):
170188 MATCH (n:Entity)-[r]-(m:Entity)
171189 WHERE elementId(n) < elementId(m)
172190 RETURN
173- {id: elementId(n), name: n.name} AS h,
191+ {id: elementId(n), name: n.name, properties: properties(n) } AS h,
174192 {
175193 id: elementId(r),
176194 type: r.type,
177195 source_id: elementId(startNode(r)),
178- target_id: elementId(endNode(r))
196+ target_id: elementId(endNode(r)),
197+ properties: properties(r)
179198 } AS r,
180- {id: elementId(m), name: m.name} AS t
199+ {id: elementId(m), name: m.name, properties: properties(m) } AS t
181200 LIMIT $num
182201 """
183202 results = tx .run (fallback_query , num = int (num ))
184203 formatted_results = {"nodes" : [], "edges" : []}
185204 node_ids = set ()
186205
187206 for item in results :
188- h_node = item ["h" ]
189- t_node = item ["t" ]
207+ h_node = _process_record_props (item ["h" ])
208+ t_node = _process_record_props (item ["t" ])
209+ r_edge = _process_record_props (item ["r" ])
190210
191211 # 避免重复添加节点
192212 if h_node ["id" ] not in node_ids :
@@ -196,7 +216,7 @@ def query(tx, num):
196216 formatted_results ["nodes" ].append (t_node )
197217 node_ids .add (t_node ["id" ])
198218
199- formatted_results ["edges" ].append (item [ "r" ] )
219+ formatted_results ["edges" ].append (r_edge )
200220
201221 return formatted_results
202222
@@ -240,18 +260,47 @@ def _index_exists(tx, index_name):
240260 return True
241261 return False
242262
263+ def _parse_node (node_data ):
264+ """解析节点数据,返回 (name, props)"""
265+ if isinstance (node_data , dict ):
266+ props = node_data .copy ()
267+ name = props .pop ("name" , "" )
268+ return name , props
269+ return str (node_data ), {}
270+
271+ def _parse_relation (rel_data ):
272+ """解析关系数据,返回 (type, props)"""
273+ if isinstance (rel_data , dict ):
274+ props = rel_data .copy ()
275+ rel_type = props .pop ("type" , "" )
276+ return rel_type , props
277+ return str (rel_data ), {}
278+
def _create_graph(tx, data):
    """Merge every triple in ``data`` into the graph with its properties.

    Each entry is expected to expose ``h``, ``t`` and ``r`` via ``.get``;
    each of those is handed to ``_parse_node`` / ``_parse_relation`` to be
    split into an identifying name/type and a map of extra properties.

    Args:
        tx: Transaction object exposing ``run(query, **params)``.
        data: Iterable of triple entries.
    """
    for entry in data:
        h_name, h_props = _parse_node(entry.get("h"))
        t_name, t_props = _parse_node(entry.get("t"))
        r_type, r_props = _parse_relation(entry.get("r"))

        # A triple needs head, tail and relation type; report what is
        # dropped instead of discarding it silently.
        if not (h_name and t_name and r_type):
            logger.warning("Skipping incomplete triple: %r", entry)
            continue

        # Nodes are matched on name and the relation on type; the extra
        # property maps are merged onto them with ``SET ... +=``.
        tx.run(
            """
            MERGE (h:Entity:Upload {name: $h_name})
            SET h += $h_props
            MERGE (t:Entity:Upload {name: $t_name})
            SET t += $t_props
            MERGE (h)-[r:RELATION {type: $r_type}]->(t)
            SET r += $r_props
            """,
            h_name=h_name,
            h_props=h_props,
            t_name=t_name,
            t_props=t_props,
            r_type=r_type,
            r_props=r_props,
        )
256305
257306 def _create_vector_index (tx , dim ):
@@ -273,6 +322,9 @@ def _get_nodes_without_embedding(tx, entity_names):
273322 # 构建参数字典,将列表转换为"param0"、"param1"等键值对形式
274323 params = {f"param{ i } " : name for i , name in enumerate (entity_names )}
275324
325+ if not params :
326+ return []
327+
276328 # 构建查询参数列表
277329 param_placeholders = ", " .join ([f"${ key } " for key in params .keys ()])
278330
@@ -315,20 +367,24 @@ def _batch_set_embeddings(tx, entity_embedding_pairs):
315367 session .execute_write (_create_vector_index , getattr (cur_embed_info , "dimension" , 1024 ))
316368
317369 # 收集所有需要处理的实体名称,去重
318- all_entities = []
370+ all_entities = set ()
319371 for entry in triples :
320- if entry ["h" ] not in all_entities :
321- all_entities .append (entry ["h" ])
322- if entry ["t" ] not in all_entities :
323- all_entities .append (entry ["t" ])
372+ h_name , _ = _parse_node (entry .get ("h" ))
373+ t_name , _ = _parse_node (entry .get ("t" ))
374+ if h_name :
375+ all_entities .add (h_name )
376+ if t_name :
377+ all_entities .add (t_name )
378+
379+ all_entities_list = list (all_entities )
324380
325381 # 筛选出没有embedding的节点
326- nodes_without_embedding = session .execute_read (_get_nodes_without_embedding , all_entities )
382+ nodes_without_embedding = session .execute_read (_get_nodes_without_embedding , all_entities_list )
327383 if not nodes_without_embedding :
328384 logger .info ("所有实体已有embedding,无需重新计算" )
329385 return
330386
331- logger .info (f"需要为{ len (nodes_without_embedding )} /{ len (all_entities )} 个实体计算embedding" )
387+ logger .info (f"需要为{ len (nodes_without_embedding )} /{ len (all_entities_list )} 个实体计算embedding" )
332388
333389 # 批量处理实体
334390 max_batch_size = 1024 # 限制此部分的主要是内存大小 1024 * 1024 * 4 / 1024 / 1024 = 4GB
@@ -597,30 +653,46 @@ def _query_specific_entity(self, entity_name, kgdb_name="neo4j", hops=2, limit=1
597653
598654 self .use_database (kgdb_name )
599655
656+ def _process_record_props (record ):
657+ """处理记录中的属性:扁平化 properties 并移除 embedding"""
658+ if record is None :
659+ return None
660+
661+ # 复制一份以避免修改原字典
662+ data = dict (record )
663+ props = data .pop ("properties" , {}) or {}
664+
665+ # 移除 embedding
666+ if "embedding" in props :
667+ del props ["embedding" ]
668+
669+ # 合并属性(优先保留原字典中的 id, name, type 等核心字段)
670+ return {** props , ** data }
671+
600672 def query (tx , entity_name , hops , limit ):
601673 try :
602674 query_str = """
603675 WITH [
604676 // 1跳出边
605677 [(n {name: $entity_name})-[r1]->(m1) |
606- {h: {id: elementId(n), name: n.name},
607- r: {id: elementId(r1), type: r1.type, source_id: elementId(n), target_id: elementId(m1)},
608- t: {id: elementId(m1), name: m1.name}}],
678+ {h: {id: elementId(n), name: n.name, properties: properties(n) },
679+ r: {id: elementId(r1), type: r1.type, source_id: elementId(n), target_id: elementId(m1), properties: properties(r1) },
680+ t: {id: elementId(m1), name: m1.name, properties: properties(m1) }}],
609681 // 2跳出边
610682 [(n {name: $entity_name})-[r1]->(m1)-[r2]->(m2) |
611- {h: {id: elementId(m1), name: m1.name},
612- r: {id: elementId(r2), type: r2.type, source_id: elementId(m1), target_id: elementId(m2)},
613- t: {id: elementId(m2), name: m2.name}}],
683+ {h: {id: elementId(m1), name: m1.name, properties: properties(m1) },
684+ r: {id: elementId(r2), type: r2.type, source_id: elementId(m1), target_id: elementId(m2), properties: properties(r2) },
685+ t: {id: elementId(m2), name: m2.name, properties: properties(m2) }}],
614686 // 1跳入边
615687 [(m1)-[r1]->(n {name: $entity_name}) |
616- {h: {id: elementId(m1), name: m1.name},
617- r: {id: elementId(r1), type: r1.type, source_id: elementId(m1), target_id: elementId(n)},
618- t: {id: elementId(n), name: n.name}}],
688+ {h: {id: elementId(m1), name: m1.name, properties: properties(m1) },
689+ r: {id: elementId(r1), type: r1.type, source_id: elementId(m1), target_id: elementId(n), properties: properties(r1) },
690+ t: {id: elementId(n), name: n.name, properties: properties(n) }}],
619691 // 2跳入边
620692 [(m2)-[r2]->(m1)-[r1]->(n {name: $entity_name}) |
621- {h: {id: elementId(m2), name: m2.name},
622- r: {id: elementId(r2), type: r2.type, source_id: elementId(m2), target_id: elementId(m1)},
623- t: {id: elementId(m1), name: m1.name}}]
693+ {h: {id: elementId(m2), name: m2.name, properties: properties(m2) },
694+ r: {id: elementId(r2), type: r2.type, source_id: elementId(m2), target_id: elementId(m1), properties: properties(r2) },
695+ t: {id: elementId(m1), name: m1.name, properties: properties(m1) }}]
624696 ] AS all_results
625697 UNWIND all_results AS result_list
626698 UNWIND result_list AS item
@@ -636,9 +708,13 @@ def query(tx, entity_name, hops, limit):
636708 formatted_results = {"nodes" : [], "edges" : [], "triples" : []}
637709
638710 for item in results :
639- formatted_results ["nodes" ].extend ([item ["h" ], item ["t" ]])
640- formatted_results ["edges" ].append (item ["r" ])
641- formatted_results ["triples" ].append ((item ["h" ]["name" ], item ["r" ]["type" ], item ["t" ]["name" ]))
711+ h = _process_record_props (item ["h" ])
712+ r = _process_record_props (item ["r" ])
713+ t = _process_record_props (item ["t" ])
714+
715+ formatted_results ["nodes" ].extend ([h , t ])
716+ formatted_results ["edges" ].append (r )
717+ formatted_results ["triples" ].append ((h ["name" ], r ["type" ], t ["name" ]))
642718
643719 logger .debug (f"Query Results: { results } " )
644720 return formatted_results
0 commit comments