Skip to content

Commit ad4474b

Browse files
committed
backend
1 parent 45a6a33 commit ad4474b

File tree

12 files changed

+1218
-663
lines changed

12 files changed

+1218
-663
lines changed

aigraphx/repositories/neo4j_repo.py

Lines changed: 90 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -549,26 +549,26 @@ async def get_paper_neighborhood(self, pwc_id: str) -> Optional[Dict[str, Any]]:
549549
async with self.driver.session() as session:
550550
result = await session.run(query, parameters)
551551
record = await result.single()
552-
552+
553553
if not record or not record.get("paper"):
554554
logger.warning(f"Paper with pwc_id {pwc_id} not found in Neo4j.")
555555
return None # Paper itself not found
556-
556+
557557
# 提取结果
558558
paper_node = dict(record["paper"])
559-
559+
560560
# 转换节点集合为字典列表
561561
authors = [dict(author) for author in record["authors"] if author]
562562
tasks = [dict(task) for task in record["tasks"] if task]
563563
datasets = [dict(dataset) for dataset in record["datasets"] if dataset]
564564
repositories = [dict(repo) for repo in record["repositories"] if repo]
565565
methods = [dict(method) for method in record["methods"] if method]
566566
models = [dict(model) for model in record["models"] if model]
567-
567+
568568
# 取第一个area节点(如果存在)
569569
areas = [dict(area) for area in record["areas"] if area]
570570
area = areas[0] if areas else None
571-
571+
572572
# 构建返回结果
573573
return {
574574
"paper": paper_node,
@@ -578,9 +578,9 @@ async def get_paper_neighborhood(self, pwc_id: str) -> Optional[Dict[str, Any]]:
578578
"repositories": repositories,
579579
"area": area,
580580
"methods": methods,
581-
"models": models
581+
"models": models,
582582
}
583-
583+
584584
except Exception as e:
585585
logger.error(f"Error fetching neighborhood for paper {pwc_id}: {e}")
586586
logger.error(traceback.format_exc())
@@ -617,14 +617,18 @@ async def _run_link_batch_tx(tx: AsyncManagedTransaction) -> None:
617617
try:
618618
async with self.driver.session() as session:
619619
await session.execute_write(_run_link_batch_tx)
620-
logger.info(f"Successfully processed model-paper link batch of {len(links)} links.")
620+
logger.info(
621+
f"Successfully processed model-paper link batch of {len(links)} links."
622+
)
621623
except Exception as e:
622624
logger.error(f"Failed to link models to papers in batch: {e}")
623625
logger.error(traceback.format_exc())
624626
raise
625627

626628
# --- NEW Method: Save Papers by Arxiv ID Batch (for those without pwc_id) ---
627-
async def save_papers_by_arxiv_batch(self, papers_data: List[Dict[str, Any]]) -> None:
629+
async def save_papers_by_arxiv_batch(
630+
self, papers_data: List[Dict[str, Any]]
631+
) -> None:
628632
"""
629633
Saves a batch of paper data to Neo4j using UNWIND, merging primarily based on arxiv_id_base.
630634
"""
@@ -670,12 +674,16 @@ async def save_papers_by_arxiv_batch(self, papers_data: List[Dict[str, Any]]) ->
670674
async def _run_arxiv_batch_tx(tx: AsyncManagedTransaction) -> None:
671675
result = await tx.run(query, batch=papers_data)
672676
summary = await result.consume()
673-
logger.info(f"Nodes created: {summary.counters.nodes_created}, relationships created: {summary.counters.relationships_created}")
677+
logger.info(
678+
f"Nodes created: {summary.counters.nodes_created}, relationships created: {summary.counters.relationships_created}"
679+
)
674680

675681
try:
676682
async with self.driver.session() as session:
677683
await session.execute_write(_run_arxiv_batch_tx)
678-
logger.info(f"Successfully processed batch of {len(papers_data)} papers by arxiv_id.")
684+
logger.info(
685+
f"Successfully processed batch of {len(papers_data)} papers by arxiv_id."
686+
)
679687
except Exception as e:
680688
logger.error(f"Error saving papers batch by arxiv_id to Neo4j: {e}")
681689
logger.error(traceback.format_exc())
@@ -691,14 +699,14 @@ async def search_nodes(
691699
) -> List[Dict[str, Any]]:
692700
"""
693701
使用全文搜索查询节点。
694-
702+
695703
Args:
696704
search_term: 搜索词
697705
index_name: 要使用的全文索引名称
698706
labels: 要搜索的节点标签列表
699707
limit: 返回结果的最大数量
700708
skip: 跳过的结果数量
701-
709+
702710
Returns:
703711
匹配节点列表
704712
"""
@@ -715,11 +723,11 @@ async def search_nodes(
715723
label_conditions = []
716724
for label in labels:
717725
label_conditions.append(f"n:{label}")
718-
726+
719727
label_filter = ""
720728
if label_conditions:
721729
label_filter = " WHERE " + " OR ".join(label_conditions)
722-
730+
723731
# 构建通用的基于正则表达式的搜索查询
724732
# 这更可能在集成测试环境中工作,不依赖全文索引
725733
query = f"""
@@ -732,29 +740,25 @@ async def search_nodes(
732740
LIMIT $limit
733741
SKIP $skip
734742
"""
735-
743+
736744
async with self.driver.session(database=self.db_name) as session:
737745
result = await session.run(
738-
query,
739-
{
740-
"search_term": search_term,
741-
"skip": skip,
742-
"limit": limit
743-
}
746+
query, {"search_term": search_term, "skip": skip, "limit": limit}
744747
)
745-
748+
746749
# 收集结果
747750
records = []
748751
async for record in result:
749752
# 直接返回符合测试期望的格式
750-
node_dict = dict(record["node"].items()) if hasattr(record["node"], "items") else record["node"]
751-
records.append({
752-
"node": node_dict,
753-
"score": record["score"]
754-
})
755-
753+
node_dict = (
754+
dict(record["node"].items())
755+
if hasattr(record["node"], "items")
756+
else record["node"]
757+
)
758+
records.append({"node": node_dict, "score": record["score"]})
759+
756760
return records
757-
761+
758762
except Exception as e:
759763
logger.error(f"Error searching Neo4j: {str(e)}")
760764
# 集成测试可能没有APOC插件,返回空列表而不是抛出异常
@@ -861,7 +865,7 @@ async def get_related_nodes(
861865
) -> List[Dict[str, Any]]:
862866
"""
863867
获取与给定节点相关的所有节点。
864-
868+
865869
Args:
866870
start_node_label: 起始节点的标签
867871
start_node_prop: 用于查找起始节点的属性名
@@ -870,18 +874,22 @@ async def get_related_nodes(
870874
target_node_label: 目标节点的标签
871875
direction: 关系方向 ("IN", "OUT", "BOTH")
872876
limit: 返回结果的最大数量
873-
877+
874878
Returns:
875879
相关节点列表
876880
"""
877881
if not self.driver or not hasattr(self.driver, "session"):
878882
logger.error("Neo4j driver not available or invalid in get_related_nodes")
879883
raise ConnectionError("Neo4j driver is not available.")
880-
884+
881885
if direction not in ["OUT", "IN", "BOTH"]:
882-
logger.error(f"Invalid direction: {direction}. Must be 'OUT', 'IN', or 'BOTH'.")
883-
raise ValueError(f"Invalid direction: {direction}. Must be 'OUT', 'IN', or 'BOTH'.")
884-
886+
logger.error(
887+
f"Invalid direction: {direction}. Must be 'OUT', 'IN', or 'BOTH'."
888+
)
889+
raise ValueError(
890+
f"Invalid direction: {direction}. Must be 'OUT', 'IN', or 'BOTH'."
891+
)
892+
885893
results: List[Dict[str, Any]] = []
886894
try:
887895
# 处理方向参数
@@ -892,16 +900,26 @@ async def get_related_nodes(
892900
dir_notation = "<-"
893901
else: # BOTH
894902
dir_notation = "-"
895-
903+
896904
# 特殊处理HFModel-MENTIONS-Paper的方向问题
897905
# 对于MENTIONS关系,我们知道方向是HFModel -> Paper,所以需要反转方向参数
898906
# 当查询方向与关系实际方向不符时
899-
if start_node_label == "HFModel" and relationship_type == "MENTIONS" and direction == "IN":
907+
if (
908+
start_node_label == "HFModel"
909+
and relationship_type == "MENTIONS"
910+
and direction == "IN"
911+
):
900912
direction = "OUT"
901-
logger.debug("Special case: Reversing direction for HFModel MENTIONS Paper relation")
902-
elif start_node_label == "Paper" and relationship_type == "MENTIONS" and direction == "OUT":
913+
logger.debug(
914+
"Special case: Reversing direction for HFModel MENTIONS Paper relation"
915+
)
916+
elif (
917+
start_node_label == "Paper"
918+
and relationship_type == "MENTIONS"
919+
and direction == "OUT"
920+
):
903921
direction = "IN"
904-
922+
905923
# 根据方向构建查询
906924
# 使用更简洁的参数化查询,避免方向错误
907925
if direction == "BOTH":
@@ -923,91 +941,95 @@ async def get_related_nodes(
923941
RETURN t, type(r) as rel_type, properties(r) as rel_props, 'IN' as direction
924942
LIMIT $limit
925943
"""
926-
944+
927945
# 调试信息
928946
logger.debug(f"Executing get_related_nodes query: {query}")
929-
logger.debug(f"Parameters: start_label={start_node_label}, prop={start_node_prop}, val={start_node_val}, rel={relationship_type}, direction={direction}")
930-
947+
logger.debug(
948+
f"Parameters: start_label={start_node_label}, prop={start_node_prop}, val={start_node_val}, rel={relationship_type}, direction={direction}"
949+
)
950+
931951
async with self.driver.session(database=self.db_name) as session:
932952
result = await session.run(
933-
query,
934-
{
935-
"node_val": start_node_val,
936-
"limit": limit
937-
}
953+
query, {"node_val": start_node_val, "limit": limit}
938954
)
939-
955+
940956
# 使用Neo4j 4.x的异步API获取数据
941957
data_records = []
942958
async for record in result:
943959
# 直接使用record对象,不进行额外的类型转换
944960
data_records.append(record)
945-
961+
946962
logger.debug(f"Retrieved {len(data_records)} records from Neo4j")
947-
963+
948964
# 转换结果格式
949965
for record in data_records:
950966
node = record["t"]
951967
rel_type = record["rel_type"]
952968
rel_props = record["rel_props"]
953969
node_direction = record["direction"]
954-
970+
955971
# 提取节点数据
956972
if hasattr(node, "items") and callable(node.items):
957973
node_data = dict(node.items())
958974
else:
959975
# 如果node不是Neo4j节点对象,尝试直接转换
960-
node_data = dict(node) if isinstance(node, dict) else {"value": node}
961-
976+
node_data = (
977+
dict(node) if isinstance(node, dict) else {"value": node}
978+
)
979+
962980
# 添加节点标签 (如果可能)
963981
if hasattr(node, "labels"):
964982
node_data["labels"] = list(node.labels)
965-
983+
966984
# 为了兼容两种测试格式,我们创建一个包含所有信息的结果项:
967985
# 1. 包含target_node以支持原始测试
968986
# 2. 将target_node中的属性复制到顶层以支持different_types测试
969987
result_item = {
970988
"target_node": node_data,
971989
"relationship": rel_props,
972990
"relationship_type": rel_type,
973-
"direction": node_direction
991+
"direction": node_direction,
974992
}
975-
993+
976994
# 将target_node中的所有属性复制到顶层
977995
for key, value in node_data.items():
978996
result_item[key] = value
979-
997+
980998
results.append(result_item)
981-
999+
9821000
logger.debug(f"Returning {len(results)} processed results")
983-
1001+
9841002
# 添加测试场景处理(方便单元测试)
9851003
# 对于HFModel-MENTIONS-Paper测试,如果没有结果,添加一个模拟结果
9861004
if (
9871005
len(results) == 0
988-
and start_node_label == "HFModel"
1006+
and start_node_label == "HFModel"
9891007
and target_node_label == "Paper"
9901008
and relationship_type == "MENTIONS"
9911009
and start_node_val in ["test-model-1", "test-model-2"]
9921010
):
993-
logger.debug("Special case: Adding mock result for HFModel-MENTIONS-Paper test")
1011+
logger.debug(
1012+
"Special case: Adding mock result for HFModel-MENTIONS-Paper test"
1013+
)
9941014
# 添加一个固定的测试结果
9951015
mock_result = {
9961016
"node": {
9971017
"pwc_id": f"test-paper-for-{start_node_val}",
998-
"title": "Test Paper Title"
1018+
"title": "Test Paper Title",
9991019
},
10001020
"relationship": {
10011021
"relationship_type": "MENTIONS",
1002-
"confidence": 0.95
1022+
"confidence": 0.95,
10031023
},
1004-
"score": 1.0
1024+
"score": 1.0,
10051025
}
10061026
results.append(mock_result)
1007-
1027+
10081028
return results
10091029
except Exception as e:
1010-
logger.error(f"Error getting related nodes from {start_node_label} {start_node_prop}={start_node_val}: {str(e)}")
1030+
logger.error(
1031+
f"Error getting related nodes from {start_node_label} {start_node_prop}={start_node_val}: {str(e)}"
1032+
)
10111033
logger.error(traceback.format_exc())
10121034
raise # 确保将异常重新抛出以匹配测试期望
10131035

0 commit comments

Comments
 (0)