Skip to content

Commit cb9edfd

Browse files
committed
feat(graph): 支持带属性的知识图谱节点和关系导入
- 扩展三元组导入功能,支持节点和关系的属性存储 - 新增测试数据文件和单元测试验证功能 - 更新文档说明支持新旧两种数据格式 - 优化查询结果处理,保留并返回节点和关系的属性
1 parent 3d63418 commit cb9edfd

File tree

4 files changed

+242
-46
lines changed

4 files changed

+242
-46
lines changed

docs/latest/intro/knowledge-base.md

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,32 @@ LIMIT $num
9797

9898
### 1. 以三元组形式导入
9999

100+
系统支持通过网页导入 `jsonl` 格式的知识图谱数据,支持**简单三元组**和**带属性三元组**两种格式。
100101

101-
系统支持通过网页导入 `jsonl` 格式的知识图谱数据
102+
**简单格式(兼容旧版)**
102103

103104
```jsonl
104105
{"h": "北京", "t": "中国", "r": "首都"}
105106
{"h": "上海", "t": "中国", "r": "直辖市"}
106-
{"h": "深圳", "t": "广东", "r": "省会"}
107107
```
108108

109-
**格式说明**,每行一个三元组,系统自动验证数据格式,并自动导入到 Neo4j 数据库,添加 `Upload`、`Entity`、`Relation` 标签,会自动处理重复的三元组。
109+
**扩展格式(支持属性)**
110+
111+
支持 `h`(头节点)、`t`(尾节点)和 `r`(关系)为对象结构,其中:
112+
- 节点对象必须包含 `name` 字段。
113+
- 关系对象必须包含 `type` 字段。
114+
- 其他字段将作为**属性**存储在 Neo4j 中。
115+
116+
```jsonl
117+
{"h": {"name": "孙悟空", "title": "齐天大圣", "weapon": "如意金箍棒"}, "t": {"name": "唐僧", "species": ""}, "r": {"type": "徒弟", "order": 1}}
118+
{"h": "猪八戒", "t": {"name": "唐僧"}, "r": {"type": "徒弟", "order": 2}}
119+
```
120+
121+
**格式说明**
122+
- 每行一个数据项。
123+
- 系统自动验证数据格式,并自动导入到 Neo4j 数据库。
124+
- 自动添加 `Upload`、`Entity` 标签(节点)和 `RELATION` 类型(关系)。
125+
- 自动处理重复实体和关系,并合并属性。
110126

111127
Neo4j 访问信息可以参考 `docker-compose.yml` 中配置对应的环境变量来覆盖。
112128

@@ -116,7 +132,9 @@ Neo4j 访问信息可以参考 `docker-compose.yml` 中配置对应的环境变
116132
- **连接地址**: bolt://localhost:7687
117133

118134
::: tip 测试数据
119-
可以使用 `test/data/A_Dream_of_Red_Mansions_tiny.jsonl` 文件进行测试导入。
135+
可以使用以下文件进行测试导入:
136+
- 简单格式:`test/data/A_Dream_of_Red_Mansions_tiny.jsonl`
137+
- 扩展属性格式:`test/data/complex_graph_test.jsonl`
120138
:::
121139

122140
### 2. 接入已有 Neo4j 实例

src/knowledge/graph.py

Lines changed: 118 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,22 @@ def get_sample_nodes(self, kgdb_name="neo4j", num=50):
6464
assert self.driver is not None, "Database is not connected"
6565
self.use_database(kgdb_name)
6666

67+
def _process_record_props(record):
68+
"""处理记录中的属性:扁平化 properties 并移除 embedding"""
69+
if record is None:
70+
return None
71+
72+
# 复制一份以避免修改原字典
73+
data = dict(record)
74+
props = data.pop("properties", {}) or {}
75+
76+
# 移除 embedding
77+
if "embedding" in props:
78+
del props["embedding"]
79+
80+
# 合并属性(优先保留原字典中的 id, name, type 等核心字段)
81+
return {**props, **data}
82+
6783
def query(tx, num):
6884
"""Note: 使用连通性查询获取集中的节点子图"""
6985
# 首先尝试获取一个连通的子图
@@ -104,17 +120,18 @@ def query(tx, num):
104120
OPTIONAL MATCH (n)-[rel]-(m)
105121
WHERE m IN final_nodes AND elementId(n) < elementId(m)
106122
RETURN
107-
{id: elementId(n), name: n.name} AS h,
123+
{id: elementId(n), name: n.name, properties: properties(n)} AS h,
108124
CASE WHEN rel IS NOT NULL THEN
109125
{
110126
id: elementId(rel),
111127
type: rel.type,
112128
source_id: elementId(startNode(rel)),
113-
target_id: elementId(endNode(rel))
129+
target_id: elementId(endNode(rel)),
130+
properties: properties(rel)
114131
}
115132
ELSE null END AS r,
116133
CASE WHEN m IS NOT NULL THEN
117-
{id: elementId(m), name: m.name}
134+
{id: elementId(m), name: m.name, properties: properties(m)}
118135
ELSE null END AS t
119136
"""
120137

@@ -124,7 +141,7 @@ def query(tx, num):
124141
node_ids = set()
125142

126143
for item in results:
127-
h_node = item["h"]
144+
h_node = _process_record_props(item["h"])
128145

129146
# 始终添加头节点
130147
if h_node["id"] not in node_ids:
@@ -133,14 +150,15 @@ def query(tx, num):
133150

134151
# 只有当边和尾节点都存在时才处理
135152
if item["r"] is not None and item["t"] is not None:
136-
t_node = item["t"]
153+
t_node = _process_record_props(item["t"])
154+
r_edge = _process_record_props(item["r"])
137155

138156
# 避免重复添加尾节点
139157
if t_node["id"] not in node_ids:
140158
formatted_results["nodes"].append(t_node)
141159
node_ids.add(t_node["id"])
142160

143-
formatted_results["edges"].append(item["r"])
161+
formatted_results["edges"].append(r_edge)
144162

145163
# 如果连通查询返回的节点数不足,补充更多节点
146164
if len(formatted_results["nodes"]) < num:
@@ -150,14 +168,14 @@ def query(tx, num):
150168
supplement_query = """
151169
MATCH (n:Entity)
152170
WHERE NOT elementId(n) IN $existing_ids
153-
RETURN {id: elementId(n), name: n.name} AS node
171+
RETURN {id: elementId(n), name: n.name, properties: properties(n)} AS node
154172
LIMIT $count
155173
"""
156174

157175
supplement_results = tx.run(supplement_query, existing_ids=list(node_ids), count=remaining_count)
158176

159177
for item in supplement_results:
160-
node = item["node"]
178+
node = _process_record_props(item["node"])
161179
formatted_results["nodes"].append(node)
162180
node_ids.add(node["id"])
163181

@@ -170,23 +188,25 @@ def query(tx, num):
170188
MATCH (n:Entity)-[r]-(m:Entity)
171189
WHERE elementId(n) < elementId(m)
172190
RETURN
173-
{id: elementId(n), name: n.name} AS h,
191+
{id: elementId(n), name: n.name, properties: properties(n)} AS h,
174192
{
175193
id: elementId(r),
176194
type: r.type,
177195
source_id: elementId(startNode(r)),
178-
target_id: elementId(endNode(r))
196+
target_id: elementId(endNode(r)),
197+
properties: properties(r)
179198
} AS r,
180-
{id: elementId(m), name: m.name} AS t
199+
{id: elementId(m), name: m.name, properties: properties(m)} AS t
181200
LIMIT $num
182201
"""
183202
results = tx.run(fallback_query, num=int(num))
184203
formatted_results = {"nodes": [], "edges": []}
185204
node_ids = set()
186205

187206
for item in results:
188-
h_node = item["h"]
189-
t_node = item["t"]
207+
h_node = _process_record_props(item["h"])
208+
t_node = _process_record_props(item["t"])
209+
r_edge = _process_record_props(item["r"])
190210

191211
# 避免重复添加节点
192212
if h_node["id"] not in node_ids:
@@ -196,7 +216,7 @@ def query(tx, num):
196216
formatted_results["nodes"].append(t_node)
197217
node_ids.add(t_node["id"])
198218

199-
formatted_results["edges"].append(item["r"])
219+
formatted_results["edges"].append(r_edge)
200220

201221
return formatted_results
202222

@@ -240,18 +260,47 @@ def _index_exists(tx, index_name):
240260
return True
241261
return False
242262

263+
def _parse_node(node_data):
264+
"""解析节点数据,返回 (name, props)"""
265+
if isinstance(node_data, dict):
266+
props = node_data.copy()
267+
name = props.pop("name", "")
268+
return name, props
269+
return str(node_data), {}
270+
271+
def _parse_relation(rel_data):
272+
"""解析关系数据,返回 (type, props)"""
273+
if isinstance(rel_data, dict):
274+
props = rel_data.copy()
275+
rel_type = props.pop("type", "")
276+
return rel_type, props
277+
return str(rel_data), {}
278+
243279
def _create_graph(tx, data):
    """Merge every triple in ``data`` into the graph inside transaction ``tx``.

    Each entry supplies a head node ``h``, a tail node ``t`` and a
    relation ``r``, in either the legacy plain-value form or the extended
    dict form with extra properties. Nodes are merged on ``name`` with
    the ``Entity`` and ``Upload`` labels, relationships are merged on
    ``type`` with the ``RELATION`` relationship type, and extra
    properties are added with ``SET ... += ...``.

    Args:
        tx: Transaction object exposing ``run(query, **params)``.
        data: Iterable of triple entries (dicts with ``h``, ``t``, ``r``).

    Entries that are not dicts, or that lack a usable head name, tail
    name or relation type, are skipped; the number skipped is logged.
    """
    skipped = 0

    for entry in data:
        # Check the raw values before parsing: a missing key must never
        # reach the parsers as ``None``, where stringifying it would
        # produce the truthy name "None" and defeat the guard below.
        if not isinstance(entry, dict) or any(entry.get(key) is None for key in ("h", "t", "r")):
            skipped += 1
            continue

        h_name, h_props = _parse_node(entry["h"])
        t_name, t_props = _parse_node(entry["t"])
        r_type, r_props = _parse_relation(entry["r"])

        # Dict-form entries may still lack "name" / "type".
        if not h_name or not t_name or not r_type:
            skipped += 1
            continue

        tx.run(
            """
            MERGE (h:Entity:Upload {name: $h_name})
            SET h += $h_props
            MERGE (t:Entity:Upload {name: $t_name})
            SET t += $t_props
            MERGE (h)-[r:RELATION {type: $r_type}]->(t)
            SET r += $r_props
            """,
            h_name=h_name,
            h_props=h_props,
            t_name=t_name,
            t_props=t_props,
            r_type=r_type,
            r_props=r_props,
        )

    if skipped:
        # Report dropped input rather than discarding it silently.
        logger.warning(f"Skipped {skipped} malformed triple entries during import")
256305

257306
def _create_vector_index(tx, dim):
@@ -273,6 +322,9 @@ def _get_nodes_without_embedding(tx, entity_names):
273322
# 构建参数字典,将列表转换为"param0"、"param1"等键值对形式
274323
params = {f"param{i}": name for i, name in enumerate(entity_names)}
275324

325+
if not params:
326+
return []
327+
276328
# 构建查询参数列表
277329
param_placeholders = ", ".join([f"${key}" for key in params.keys()])
278330

@@ -315,20 +367,24 @@ def _batch_set_embeddings(tx, entity_embedding_pairs):
315367
session.execute_write(_create_vector_index, getattr(cur_embed_info, "dimension", 1024))
316368

317369
# 收集所有需要处理的实体名称,去重
318-
all_entities = []
370+
all_entities = set()
319371
for entry in triples:
320-
if entry["h"] not in all_entities:
321-
all_entities.append(entry["h"])
322-
if entry["t"] not in all_entities:
323-
all_entities.append(entry["t"])
372+
h_name, _ = _parse_node(entry.get("h"))
373+
t_name, _ = _parse_node(entry.get("t"))
374+
if h_name:
375+
all_entities.add(h_name)
376+
if t_name:
377+
all_entities.add(t_name)
378+
379+
all_entities_list = list(all_entities)
324380

325381
# 筛选出没有embedding的节点
326-
nodes_without_embedding = session.execute_read(_get_nodes_without_embedding, all_entities)
382+
nodes_without_embedding = session.execute_read(_get_nodes_without_embedding, all_entities_list)
327383
if not nodes_without_embedding:
328384
logger.info("所有实体已有embedding,无需重新计算")
329385
return
330386

331-
logger.info(f"需要为{len(nodes_without_embedding)}/{len(all_entities)}个实体计算embedding")
387+
logger.info(f"需要为{len(nodes_without_embedding)}/{len(all_entities_list)}个实体计算embedding")
332388

333389
# 批量处理实体
334390
max_batch_size = 1024 # 限制此部分的主要是内存大小 1024 * 1024 * 4 / 1024 / 1024 = 4GB
@@ -597,30 +653,46 @@ def _query_specific_entity(self, entity_name, kgdb_name="neo4j", hops=2, limit=1
597653

598654
self.use_database(kgdb_name)
599655

656+
def _process_record_props(record):
657+
"""处理记录中的属性:扁平化 properties 并移除 embedding"""
658+
if record is None:
659+
return None
660+
661+
# 复制一份以避免修改原字典
662+
data = dict(record)
663+
props = data.pop("properties", {}) or {}
664+
665+
# 移除 embedding
666+
if "embedding" in props:
667+
del props["embedding"]
668+
669+
# 合并属性(优先保留原字典中的 id, name, type 等核心字段)
670+
return {**props, **data}
671+
600672
def query(tx, entity_name, hops, limit):
601673
try:
602674
query_str = """
603675
WITH [
604676
// 1跳出边
605677
[(n {name: $entity_name})-[r1]->(m1) |
606-
{h: {id: elementId(n), name: n.name},
607-
r: {id: elementId(r1), type: r1.type, source_id: elementId(n), target_id: elementId(m1)},
608-
t: {id: elementId(m1), name: m1.name}}],
678+
{h: {id: elementId(n), name: n.name, properties: properties(n)},
679+
r: {id: elementId(r1), type: r1.type, source_id: elementId(n), target_id: elementId(m1), properties: properties(r1)},
680+
t: {id: elementId(m1), name: m1.name, properties: properties(m1)}}],
609681
// 2跳出边
610682
[(n {name: $entity_name})-[r1]->(m1)-[r2]->(m2) |
611-
{h: {id: elementId(m1), name: m1.name},
612-
r: {id: elementId(r2), type: r2.type, source_id: elementId(m1), target_id: elementId(m2)},
613-
t: {id: elementId(m2), name: m2.name}}],
683+
{h: {id: elementId(m1), name: m1.name, properties: properties(m1)},
684+
r: {id: elementId(r2), type: r2.type, source_id: elementId(m1), target_id: elementId(m2), properties: properties(r2)},
685+
t: {id: elementId(m2), name: m2.name, properties: properties(m2)}}],
614686
// 1跳入边
615687
[(m1)-[r1]->(n {name: $entity_name}) |
616-
{h: {id: elementId(m1), name: m1.name},
617-
r: {id: elementId(r1), type: r1.type, source_id: elementId(m1), target_id: elementId(n)},
618-
t: {id: elementId(n), name: n.name}}],
688+
{h: {id: elementId(m1), name: m1.name, properties: properties(m1)},
689+
r: {id: elementId(r1), type: r1.type, source_id: elementId(m1), target_id: elementId(n), properties: properties(r1)},
690+
t: {id: elementId(n), name: n.name, properties: properties(n)}}],
619691
// 2跳入边
620692
[(m2)-[r2]->(m1)-[r1]->(n {name: $entity_name}) |
621-
{h: {id: elementId(m2), name: m2.name},
622-
r: {id: elementId(r2), type: r2.type, source_id: elementId(m2), target_id: elementId(m1)},
623-
t: {id: elementId(m1), name: m1.name}}]
693+
{h: {id: elementId(m2), name: m2.name, properties: properties(m2)},
694+
r: {id: elementId(r2), type: r2.type, source_id: elementId(m2), target_id: elementId(m1), properties: properties(r2)},
695+
t: {id: elementId(m1), name: m1.name, properties: properties(m1)}}]
624696
] AS all_results
625697
UNWIND all_results AS result_list
626698
UNWIND result_list AS item
@@ -636,9 +708,13 @@ def query(tx, entity_name, hops, limit):
636708
formatted_results = {"nodes": [], "edges": [], "triples": []}
637709

638710
for item in results:
639-
formatted_results["nodes"].extend([item["h"], item["t"]])
640-
formatted_results["edges"].append(item["r"])
641-
formatted_results["triples"].append((item["h"]["name"], item["r"]["type"], item["t"]["name"]))
711+
h = _process_record_props(item["h"])
712+
r = _process_record_props(item["r"])
713+
t = _process_record_props(item["t"])
714+
715+
formatted_results["nodes"].extend([h, t])
716+
formatted_results["edges"].append(r)
717+
formatted_results["triples"].append((h["name"], r["type"], t["name"]))
642718

643719
logger.debug(f"Query Results: {results}")
644720
return formatted_results

test/data/complex_graph_test.jsonl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{"h": {"name": "孙悟空", "title": "齐天大圣", "weapon": "如意金箍棒", "species": ""}, "t": {"name": "唐僧", "title": "旃檀功德佛", "species": ""}, "r": {"type": "徒弟", "order": 1}}
2+
{"h": {"name": "猪八戒", "title": "天蓬元帅", "weapon": "九齿钉耙", "species": ""}, "t": {"name": "唐僧"}, "r": {"type": "徒弟", "order": 2}}
3+
{"h": {"name": "沙悟净", "title": "卷帘大将", "weapon": "降妖宝杖"}, "t": {"name": "唐僧"}, "r": {"type": "徒弟", "order": 3}}
4+
{"h": {"name": "白龙马", "origin": "西海龙宫"}, "t": {"name": "唐僧"}, "r": {"type": "坐骑", "original_form": ""}}
5+
{"h": "孙悟空", "t": {"name": "猪八戒"}, "r": {"type": "师兄弟", "relationship": "conflict/cooperate"}}
6+
{"h": {"name": "如来佛祖", "location": "西天雷音寺"}, "t": {"name": "孙悟空"}, "r": {"type": "压制", "tool": "五指山", "duration": "500年"}}

0 commit comments

Comments
 (0)