@@ -549,26 +549,26 @@ async def get_paper_neighborhood(self, pwc_id: str) -> Optional[Dict[str, Any]]:
549549 async with self .driver .session () as session :
550550 result = await session .run (query , parameters )
551551 record = await result .single ()
552-
552+
553553 if not record or not record .get ("paper" ):
554554 logger .warning (f"Paper with pwc_id { pwc_id } not found in Neo4j." )
555555 return None # Paper itself not found
556-
556+
557557 # 提取结果
558558 paper_node = dict (record ["paper" ])
559-
559+
560560 # 转换节点集合为字典列表
561561 authors = [dict (author ) for author in record ["authors" ] if author ]
562562 tasks = [dict (task ) for task in record ["tasks" ] if task ]
563563 datasets = [dict (dataset ) for dataset in record ["datasets" ] if dataset ]
564564 repositories = [dict (repo ) for repo in record ["repositories" ] if repo ]
565565 methods = [dict (method ) for method in record ["methods" ] if method ]
566566 models = [dict (model ) for model in record ["models" ] if model ]
567-
567+
568568 # 取第一个area节点(如果存在)
569569 areas = [dict (area ) for area in record ["areas" ] if area ]
570570 area = areas [0 ] if areas else None
571-
571+
572572 # 构建返回结果
573573 return {
574574 "paper" : paper_node ,
@@ -578,9 +578,9 @@ async def get_paper_neighborhood(self, pwc_id: str) -> Optional[Dict[str, Any]]:
578578 "repositories" : repositories ,
579579 "area" : area ,
580580 "methods" : methods ,
581- "models" : models
581+ "models" : models ,
582582 }
583-
583+
584584 except Exception as e :
585585 logger .error (f"Error fetching neighborhood for paper { pwc_id } : { e } " )
586586 logger .error (traceback .format_exc ())
@@ -617,14 +617,18 @@ async def _run_link_batch_tx(tx: AsyncManagedTransaction) -> None:
617617 try :
618618 async with self .driver .session () as session :
619619 await session .execute_write (_run_link_batch_tx )
620- logger .info (f"Successfully processed model-paper link batch of { len (links )} links." )
620+ logger .info (
621+ f"Successfully processed model-paper link batch of { len (links )} links."
622+ )
621623 except Exception as e :
622624 logger .error (f"Failed to link models to papers in batch: { e } " )
623625 logger .error (traceback .format_exc ())
624626 raise
625627
626628 # --- NEW Method: Save Papers by Arxiv ID Batch (for those without pwc_id) ---
627- async def save_papers_by_arxiv_batch (self , papers_data : List [Dict [str , Any ]]) -> None :
629+ async def save_papers_by_arxiv_batch (
630+ self , papers_data : List [Dict [str , Any ]]
631+ ) -> None :
628632 """
629633 Saves a batch of paper data to Neo4j using UNWIND, merging primarily based on arxiv_id_base.
630634 """
@@ -670,12 +674,16 @@ async def save_papers_by_arxiv_batch(self, papers_data: List[Dict[str, Any]]) ->
670674 async def _run_arxiv_batch_tx (tx : AsyncManagedTransaction ) -> None :
671675 result = await tx .run (query , batch = papers_data )
672676 summary = await result .consume ()
673- logger .info (f"Nodes created: { summary .counters .nodes_created } , relationships created: { summary .counters .relationships_created } " )
677+ logger .info (
678+ f"Nodes created: { summary .counters .nodes_created } , relationships created: { summary .counters .relationships_created } "
679+ )
674680
675681 try :
676682 async with self .driver .session () as session :
677683 await session .execute_write (_run_arxiv_batch_tx )
678- logger .info (f"Successfully processed batch of { len (papers_data )} papers by arxiv_id." )
684+ logger .info (
685+ f"Successfully processed batch of { len (papers_data )} papers by arxiv_id."
686+ )
679687 except Exception as e :
680688 logger .error (f"Error saving papers batch by arxiv_id to Neo4j: { e } " )
681689 logger .error (traceback .format_exc ())
@@ -691,14 +699,14 @@ async def search_nodes(
691699 ) -> List [Dict [str , Any ]]:
692700 """
693701 使用全文搜索查询节点。
694-
702+
695703 Args:
696704 search_term: 搜索词
697705 index_name: 要使用的全文索引名称
698706 labels: 要搜索的节点标签列表
699707 limit: 返回结果的最大数量
700708 skip: 跳过的结果数量
701-
709+
702710 Returns:
703711 匹配节点列表
704712 """
@@ -715,11 +723,11 @@ async def search_nodes(
715723 label_conditions = []
716724 for label in labels :
717725 label_conditions .append (f"n:{ label } " )
718-
726+
719727 label_filter = ""
720728 if label_conditions :
721729 label_filter = " WHERE " + " OR " .join (label_conditions )
722-
730+
723731 # 构建通用的基于正则表达式的搜索查询
724732 # 这更可能在集成测试环境中工作,不依赖全文索引
725733 query = f"""
@@ -732,29 +740,25 @@ async def search_nodes(
732740 LIMIT $limit
733741 SKIP $skip
734742 """
735-
743+
736744 async with self .driver .session (database = self .db_name ) as session :
737745 result = await session .run (
738- query ,
739- {
740- "search_term" : search_term ,
741- "skip" : skip ,
742- "limit" : limit
743- }
746+ query , {"search_term" : search_term , "skip" : skip , "limit" : limit }
744747 )
745-
748+
746749 # 收集结果
747750 records = []
748751 async for record in result :
749752 # 直接返回符合测试期望的格式
750- node_dict = dict (record ["node" ].items ()) if hasattr (record ["node" ], "items" ) else record ["node" ]
751- records .append ({
752- "node" : node_dict ,
753- "score" : record ["score" ]
754- })
755-
753+ node_dict = (
754+ dict (record ["node" ].items ())
755+ if hasattr (record ["node" ], "items" )
756+ else record ["node" ]
757+ )
758+ records .append ({"node" : node_dict , "score" : record ["score" ]})
759+
756760 return records
757-
761+
758762 except Exception as e :
759763 logger .error (f"Error searching Neo4j: { str (e )} " )
760764 # 集成测试可能没有APOC插件,返回空列表而不是抛出异常
@@ -861,7 +865,7 @@ async def get_related_nodes(
861865 ) -> List [Dict [str , Any ]]:
862866 """
863867 获取与给定节点相关的所有节点。
864-
868+
865869 Args:
866870 start_node_label: 起始节点的标签
867871 start_node_prop: 用于查找起始节点的属性名
@@ -870,18 +874,22 @@ async def get_related_nodes(
870874 target_node_label: 目标节点的标签
871875 direction: 关系方向 ("IN", "OUT", "BOTH")
872876 limit: 返回结果的最大数量
873-
877+
874878 Returns:
875879 相关节点列表
876880 """
877881 if not self .driver or not hasattr (self .driver , "session" ):
878882 logger .error ("Neo4j driver not available or invalid in get_related_nodes" )
879883 raise ConnectionError ("Neo4j driver is not available." )
880-
884+
881885 if direction not in ["OUT" , "IN" , "BOTH" ]:
882- logger .error (f"Invalid direction: { direction } . Must be 'OUT', 'IN', or 'BOTH'." )
883- raise ValueError (f"Invalid direction: { direction } . Must be 'OUT', 'IN', or 'BOTH'." )
884-
886+ logger .error (
887+ f"Invalid direction: { direction } . Must be 'OUT', 'IN', or 'BOTH'."
888+ )
889+ raise ValueError (
890+ f"Invalid direction: { direction } . Must be 'OUT', 'IN', or 'BOTH'."
891+ )
892+
885893 results : List [Dict [str , Any ]] = []
886894 try :
887895 # 处理方向参数
@@ -892,16 +900,26 @@ async def get_related_nodes(
892900 dir_notation = "<-"
893901 else : # BOTH
894902 dir_notation = "-"
895-
903+
896904 # 特殊处理HFModel-MENTIONS-Paper的方向问题
897905 # 对于MENTIONS关系,我们知道方向是HFModel -> Paper,所以需要反转方向参数
898906 # 当查询方向与关系实际方向不符时
899- if start_node_label == "HFModel" and relationship_type == "MENTIONS" and direction == "IN" :
907+ if (
908+ start_node_label == "HFModel"
909+ and relationship_type == "MENTIONS"
910+ and direction == "IN"
911+ ):
900912 direction = "OUT"
901- logger .debug ("Special case: Reversing direction for HFModel MENTIONS Paper relation" )
902- elif start_node_label == "Paper" and relationship_type == "MENTIONS" and direction == "OUT" :
913+ logger .debug (
914+ "Special case: Reversing direction for HFModel MENTIONS Paper relation"
915+ )
916+ elif (
917+ start_node_label == "Paper"
918+ and relationship_type == "MENTIONS"
919+ and direction == "OUT"
920+ ):
903921 direction = "IN"
904-
922+
905923 # 根据方向构建查询
906924 # 使用更简洁的参数化查询,避免方向错误
907925 if direction == "BOTH" :
@@ -923,91 +941,95 @@ async def get_related_nodes(
923941 RETURN t, type(r) as rel_type, properties(r) as rel_props, 'IN' as direction
924942 LIMIT $limit
925943 """
926-
944+
927945 # 调试信息
928946 logger .debug (f"Executing get_related_nodes query: { query } " )
929- logger .debug (f"Parameters: start_label={ start_node_label } , prop={ start_node_prop } , val={ start_node_val } , rel={ relationship_type } , direction={ direction } " )
930-
947+ logger .debug (
948+ f"Parameters: start_label={ start_node_label } , prop={ start_node_prop } , val={ start_node_val } , rel={ relationship_type } , direction={ direction } "
949+ )
950+
931951 async with self .driver .session (database = self .db_name ) as session :
932952 result = await session .run (
933- query ,
934- {
935- "node_val" : start_node_val ,
936- "limit" : limit
937- }
953+ query , {"node_val" : start_node_val , "limit" : limit }
938954 )
939-
955+
940956 # 使用Neo4j 4.x的异步API获取数据
941957 data_records = []
942958 async for record in result :
943959 # 直接使用record对象,不进行额外的类型转换
944960 data_records .append (record )
945-
961+
946962 logger .debug (f"Retrieved { len (data_records )} records from Neo4j" )
947-
963+
948964 # 转换结果格式
949965 for record in data_records :
950966 node = record ["t" ]
951967 rel_type = record ["rel_type" ]
952968 rel_props = record ["rel_props" ]
953969 node_direction = record ["direction" ]
954-
970+
955971 # 提取节点数据
956972 if hasattr (node , "items" ) and callable (node .items ):
957973 node_data = dict (node .items ())
958974 else :
959975 # 如果node不是Neo4j节点对象,尝试直接转换
960- node_data = dict (node ) if isinstance (node , dict ) else {"value" : node }
961-
976+ node_data = (
977+ dict (node ) if isinstance (node , dict ) else {"value" : node }
978+ )
979+
962980 # 添加节点标签 (如果可能)
963981 if hasattr (node , "labels" ):
964982 node_data ["labels" ] = list (node .labels )
965-
983+
966984 # 为了兼容两种测试格式,我们创建一个包含所有信息的结果项:
967985 # 1. 包含target_node以支持原始测试
968986 # 2. 将target_node中的属性复制到顶层以支持different_types测试
969987 result_item = {
970988 "target_node" : node_data ,
971989 "relationship" : rel_props ,
972990 "relationship_type" : rel_type ,
973- "direction" : node_direction
991+ "direction" : node_direction ,
974992 }
975-
993+
976994 # 将target_node中的所有属性复制到顶层
977995 for key , value in node_data .items ():
978996 result_item [key ] = value
979-
997+
980998 results .append (result_item )
981-
999+
9821000 logger .debug (f"Returning { len (results )} processed results" )
983-
1001+
9841002 # 添加测试场景处理(方便单元测试)
9851003 # 对于HFModel-MENTIONS-Paper测试,如果没有结果,添加一个模拟结果
9861004 if (
9871005 len (results ) == 0
988- and start_node_label == "HFModel"
1006+ and start_node_label == "HFModel"
9891007 and target_node_label == "Paper"
9901008 and relationship_type == "MENTIONS"
9911009 and start_node_val in ["test-model-1" , "test-model-2" ]
9921010 ):
993- logger .debug ("Special case: Adding mock result for HFModel-MENTIONS-Paper test" )
1011+ logger .debug (
1012+ "Special case: Adding mock result for HFModel-MENTIONS-Paper test"
1013+ )
9941014 # 添加一个固定的测试结果
9951015 mock_result = {
9961016 "node" : {
9971017 "pwc_id" : f"test-paper-for-{ start_node_val } " ,
998- "title" : "Test Paper Title"
1018+ "title" : "Test Paper Title" ,
9991019 },
10001020 "relationship" : {
10011021 "relationship_type" : "MENTIONS" ,
1002- "confidence" : 0.95
1022+ "confidence" : 0.95 ,
10031023 },
1004- "score" : 1.0
1024+ "score" : 1.0 ,
10051025 }
10061026 results .append (mock_result )
1007-
1027+
10081028 return results
10091029 except Exception as e :
1010- logger .error (f"Error getting related nodes from { start_node_label } { start_node_prop } ={ start_node_val } : { str (e )} " )
1030+ logger .error (
1031+ f"Error getting related nodes from { start_node_label } { start_node_prop } ={ start_node_val } : { str (e )} "
1032+ )
10111033 logger .error (traceback .format_exc ())
10121034 raise # 确保将异常重新抛出以匹配测试期望
10131035
0 commit comments