Skip to content

Commit 5f60478

Browse files
committed
refactor: 优化代码格式和结构,改进错误提示信息
- 统一代码格式,移除多余空格和注释 - 改进MySQL连接错误提示信息 - 优化GraphDatabase查询格式和响应结构 - 简化测试用例参数传递方式 - 调整前端GraphCanvas组件样式
1 parent 75a7d46 commit 5f60478

File tree

9 files changed

+128
-140
lines changed

9 files changed

+128
-140
lines changed

server/routers/graph_router.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -57,29 +57,33 @@ async def get_graphs(current_user: User = Depends(get_admin_user)):
5757
# 1. 获取默认 Neo4j 图谱信息
5858
neo4j_info = graph_base.get_graph_info()
5959
if neo4j_info:
60-
graphs.append({
61-
"id": "neo4j",
62-
"name": "默认图谱",
63-
"type": "neo4j",
64-
"description": "Default graph database for uploaded documents",
65-
"status": neo4j_info.get("status", "unknown"),
66-
"created_at": neo4j_info.get("last_updated"),
67-
"node_count": neo4j_info.get("entity_count", 0),
68-
"edge_count": neo4j_info.get("relationship_count", 0)
69-
})
60+
graphs.append(
61+
{
62+
"id": "neo4j",
63+
"name": "默认图谱",
64+
"type": "neo4j",
65+
"description": "Default graph database for uploaded documents",
66+
"status": neo4j_info.get("status", "unknown"),
67+
"created_at": neo4j_info.get("last_updated"),
68+
"node_count": neo4j_info.get("entity_count", 0),
69+
"edge_count": neo4j_info.get("relationship_count", 0),
70+
}
71+
)
7072

7173
# 2. 获取 LightRAG 数据库信息
7274
lightrag_dbs = knowledge_base.get_lightrag_databases()
7375
for db in lightrag_dbs:
74-
graphs.append({
75-
"id": db.get("db_id"),
76-
"name": db.get("name"),
77-
"type": "lightrag",
78-
"description": db.get("description"),
79-
"status": "active", # LightRAG DBs are usually active if listed
80-
"created_at": db.get("created_at"),
81-
"metadata": db
82-
})
76+
graphs.append(
77+
{
78+
"id": db.get("db_id"),
79+
"name": db.get("name"),
80+
"type": "lightrag",
81+
"description": db.get("description"),
82+
"status": "active", # LightRAG DBs are usually active if listed
83+
"created_at": db.get("created_at"),
84+
"metadata": db,
85+
}
86+
)
8387

8488
return {"success": True, "data": graphs}
8589

@@ -118,7 +122,7 @@ async def get_subgraph(
118122
keyword=node_label,
119123
max_depth=max_depth,
120124
max_nodes=max_nodes,
121-
kgdb_name=db_id if not knowledge_base.is_lightrag_database(db_id) else "neo4j"
125+
kgdb_name=db_id if not knowledge_base.is_lightrag_database(db_id) else "neo4j",
122126
)
123127

124128
return {
@@ -136,8 +140,7 @@ async def get_subgraph(
136140

137141
@graph.get("/labels")
138142
async def get_graph_labels(
139-
db_id: str = Query(..., description="知识图谱ID"),
140-
current_user: User = Depends(get_admin_user)
143+
db_id: str = Query(..., description="知识图谱ID"), current_user: User = Depends(get_admin_user)
141144
):
142145
"""
143146
获取图谱的所有标签
@@ -154,8 +157,7 @@ async def get_graph_labels(
154157

155158
@graph.get("/stats")
156159
async def get_graph_stats(
157-
db_id: str = Query(..., description="知识图谱ID"),
158-
current_user: User = Depends(get_admin_user)
160+
db_id: str = Query(..., description="知识图谱ID"), current_user: User = Depends(get_admin_user)
159161
):
160162
"""
161163
获取图谱统计信息
@@ -175,23 +177,22 @@ async def get_graph_stats(
175177
entity_types[entity_type] = entity_types.get(entity_type, 0) + 1
176178

177179
entity_types_list = [
178-
{"type": k, "count": v}
179-
for k, v in sorted(entity_types.items(), key=lambda x: x[1], reverse=True)
180+
{"type": k, "count": v} for k, v in sorted(entity_types.items(), key=lambda x: x[1], reverse=True)
180181
]
181182

182183
return {
183184
"success": True,
184185
"data": {
185186
"total_nodes": len(knowledge_graph.nodes),
186187
"total_edges": len(knowledge_graph.edges),
187-
"entity_types": entity_types_list
188-
}
188+
"entity_types": entity_types_list,
189+
},
189190
}
190191
else:
191192
# Neo4j stats
192193
info = graph_base.get_graph_info(graph_name=db_id)
193194
if not info:
194-
raise HTTPException(status_code=404, detail="Graph info not found")
195+
raise HTTPException(status_code=404, detail="Graph info not found")
195196

196197
return {
197198
"success": True,
@@ -200,11 +201,8 @@ async def get_graph_stats(
200201
"total_edges": info.get("relationship_count", 0),
201202
# Neo4j info currently returns 'labels' list, not counts per label.
202203
# Improving this would require updating GraphDatabase.get_graph_info
203-
"entity_types": [
204-
{"type": label, "count": "N/A"}
205-
for label in info.get("labels", [])
206-
]
207-
}
204+
"entity_types": [{"type": label, "count": "N/A"} for label in info.get("labels", [])],
205+
},
208206
}
209207

210208
except Exception as e:
@@ -227,11 +225,7 @@ async def get_lightrag_subgraph(
227225
):
228226
"""(Deprecated) Use /graph/subgraph instead"""
229227
return await get_subgraph(
230-
db_id=db_id,
231-
node_label=node_label,
232-
max_depth=max_depth,
233-
max_nodes=max_nodes,
234-
current_user=current_user
228+
db_id=db_id, node_label=node_label, max_depth=max_depth, max_nodes=max_nodes, current_user=current_user
235229
)
236230

237231

src/agents/common/toolkits/mysql/tools.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def get_connection_manager() -> MySQLConnectionManager:
3838
required_keys = ["host", "user", "password", "database"]
3939
for key in required_keys:
4040
if not mysql_config[key]:
41-
raise MySQLConnectionError(f"MySQL configuration missing required key: {key}")
41+
raise MySQLConnectionError(
42+
f"MySQL configuration missing required key: {key}, please check your environment variables."
43+
)
4244

4345
_connection_manager = MySQLConnectionManager(mysql_config)
4446
return _connection_manager

src/knowledge/adapters/lightrag.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def normalize_node(self, raw_node: Any) -> dict[str, Any]:
6363
# 优先使用 entity_id 作为显示名称,因为 Neo4j 中 LightRAG 存储的实体名称在 entity_id 字段
6464
# 如果不存在,则回退到 id
6565
name = properties.get("entity_id", node_id)
66-
66+
6767
# 尝试从 properties 获取 entity_type,或者从 labels 中推断(排除 kb_ 前缀的 label)
6868
entity_type = properties.get("entity_type", "unknown")
6969
if entity_type == "unknown" and labels:
@@ -73,12 +73,7 @@ def normalize_node(self, raw_node: Any) -> dict[str, Any]:
7373
break
7474

7575
return self._create_standard_node(
76-
node_id=node_id,
77-
name=name,
78-
entity_type=entity_type,
79-
labels=labels,
80-
properties=properties,
81-
source="lightrag"
76+
node_id=node_id, name=name, entity_type=entity_type, labels=labels, properties=properties, source="lightrag"
8277
)
8378

8479
def normalize_edge(self, raw_edge: Any) -> dict[str, Any]:
@@ -102,7 +97,7 @@ def normalize_edge(self, raw_edge: Any) -> dict[str, Any]:
10297
properties = getattr(raw_edge, "properties", {})
10398
if not properties and hasattr(raw_edge, "get"):
10499
properties = raw_edge.get("properties", {})
105-
100+
106101
# 优化边的显示类型
107102
# LightRAG 的边类型通常是 "DIRECTED",具体含义在 keywords 或 description 中
108103
display_type = edge_type
@@ -116,14 +111,10 @@ def normalize_edge(self, raw_edge: Any) -> dict[str, Any]:
116111
if len(desc) < 20:
117112
display_type = desc
118113
else:
119-
display_type = "related" # fallback
114+
display_type = "related" # fallback
120115

121116
return self._create_standard_edge(
122-
edge_id=edge_id,
123-
source_id=source,
124-
target_id=target,
125-
edge_type=display_type,
126-
properties=properties
117+
edge_id=edge_id, source_id=source, target_id=target, edge_type=display_type, properties=properties
127118
)
128119

129120
async def get_labels(self) -> list[str]:

src/knowledge/adapters/upload.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,12 @@ def __init__(self, graph_db_instance: GraphDatabase, config: dict[str, Any] = No
1919

2020
async def query_nodes(self, keyword: str, **kwargs) -> dict[str, Any]:
2121
params = self._normalize_query_params(keyword, kwargs)
22-
22+
2323
# 如果关键词是 "*" 或者为空,则执行采样查询
2424
if not params["keyword"] or params["keyword"] == "*":
2525
# 映射 max_nodes 到 num
2626
num = kwargs.get("max_nodes", 100)
27-
raw_results = self.graph_db.get_sample_nodes(
28-
kgdb_name=params.get("kgdb_name", "neo4j"),
29-
num=num
30-
)
27+
raw_results = self.graph_db.get_sample_nodes(kgdb_name=params.get("kgdb_name", "neo4j"), num=num)
3128
else:
3229
# 否则执行关键词搜索
3330
# graph_db.query_node is sync
@@ -66,7 +63,7 @@ def normalize_node(self, raw_node: Any) -> dict[str, Any]:
6663
entity_type="entity",
6764
labels=["Entity", "Upload"],
6865
properties=raw_node,
69-
source="upload"
66+
source="upload",
7067
)
7168

7269
def normalize_edge(self, raw_edge: Any) -> dict[str, Any]:
@@ -83,7 +80,7 @@ def normalize_edge(self, raw_edge: Any) -> dict[str, Any]:
8380
source_id=raw_edge.get("source_id"),
8481
target_id=raw_edge.get("target_id"),
8582
edge_type=raw_edge.get("type"),
86-
properties=raw_edge
83+
properties=raw_edge,
8784
)
8885

8986
async def get_labels(self) -> list[str]:

src/knowledge/graph.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ def _process_record_props(record):
6868
"""处理记录中的属性:扁平化 properties 并移除 embedding"""
6969
if record is None:
7070
return None
71-
71+
7272
# 复制一份以避免修改原字典
7373
data = dict(record)
7474
props = data.pop("properties", {}) or {}
75-
75+
7676
# 移除 embedding
7777
if "embedding" in props:
7878
del props["embedding"]
79-
79+
8080
# 合并属性(优先保留原字典中的 id, name, type 等核心字段)
8181
return {**props, **data}
8282

@@ -657,15 +657,15 @@ def _process_record_props(record):
657657
"""处理记录中的属性:扁平化 properties 并移除 embedding"""
658658
if record is None:
659659
return None
660-
660+
661661
# 复制一份以避免修改原字典
662662
data = dict(record)
663663
props = data.pop("properties", {}) or {}
664-
664+
665665
# 移除 embedding
666666
if "embedding" in props:
667667
del props["embedding"]
668-
668+
669669
# 合并属性(优先保留原字典中的 id, name, type 等核心字段)
670670
return {**props, **data}
671671

@@ -676,22 +676,46 @@ def query(tx, entity_name, hops, limit):
676676
// 1跳出边
677677
[(n {name: $entity_name})-[r1]->(m1) |
678678
{h: {id: elementId(n), name: n.name, properties: properties(n)},
679-
r: {id: elementId(r1), type: r1.type, source_id: elementId(n), target_id: elementId(m1), properties: properties(r1)},
679+
r: {
680+
id: elementId(r1),
681+
type: r1.type,
682+
source_id: elementId(n),
683+
target_id: elementId(m1),
684+
properties: properties(r1)
685+
},
680686
t: {id: elementId(m1), name: m1.name, properties: properties(m1)}}],
681687
// 2跳出边
682688
[(n {name: $entity_name})-[r1]->(m1)-[r2]->(m2) |
683689
{h: {id: elementId(m1), name: m1.name, properties: properties(m1)},
684-
r: {id: elementId(r2), type: r2.type, source_id: elementId(m1), target_id: elementId(m2), properties: properties(r2)},
690+
r: {
691+
id: elementId(r2),
692+
type: r2.type,
693+
source_id: elementId(m1),
694+
target_id: elementId(m2),
695+
properties: properties(r2)
696+
},
685697
t: {id: elementId(m2), name: m2.name, properties: properties(m2)}}],
686698
// 1跳入边
687699
[(m1)-[r1]->(n {name: $entity_name}) |
688700
{h: {id: elementId(m1), name: m1.name, properties: properties(m1)},
689-
r: {id: elementId(r1), type: r1.type, source_id: elementId(m1), target_id: elementId(n), properties: properties(r1)},
701+
r: {
702+
id: elementId(r1),
703+
type: r1.type,
704+
source_id: elementId(m1),
705+
target_id: elementId(n),
706+
properties: properties(r1)
707+
},
690708
t: {id: elementId(n), name: n.name, properties: properties(n)}}],
691709
// 2跳入边
692710
[(m2)-[r2]->(m1)-[r1]->(n {name: $entity_name}) |
693711
{h: {id: elementId(m2), name: m2.name, properties: properties(m2)},
694-
r: {id: elementId(r2), type: r2.type, source_id: elementId(m2), target_id: elementId(m1), properties: properties(r2)},
712+
r: {
713+
id: elementId(r2),
714+
type: r2.type,
715+
source_id: elementId(m2),
716+
target_id: elementId(m1),
717+
properties: properties(r2)
718+
},
695719
t: {id: elementId(m1), name: m1.name, properties: properties(m1)}}]
696720
] AS all_results
697721
UNWIND all_results AS result_list
@@ -711,7 +735,7 @@ def query(tx, entity_name, hops, limit):
711735
h = _process_record_props(item["h"])
712736
r = _process_record_props(item["r"])
713737
t = _process_record_props(item["t"])
714-
738+
715739
formatted_results["nodes"].extend([h, t])
716740
formatted_results["edges"].append(r)
717741
formatted_results["triples"].append((h["name"], r["type"], t["name"]))

src/knowledge/implementations/lightrag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def _get_embedding_func(self, embed_info: dict):
235235
model=model_name,
236236
api_key=config_dict["api_key"],
237237
base_url=config_dict["base_url"].replace("/embeddings", ""),
238-
)
238+
),
239239
)
240240

241241
async def add_content(self, db_id: str, items: list[str], params: dict | None = None) -> list[dict]:

0 commit comments

Comments
 (0)