@@ -66,7 +66,16 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
6666 taxonomic_pairs = [{"parent" : candidate , "child" : query }
6767 for query , candidates in zip (data , candidates_lst )
6868 for candidate in candidates if candidate .lower () != query .lower ()]
69- return taxonomic_pairs
69+ taxonomic_pairs += [{"parent" : query , "child" : candidate }
70+ for query , candidates in zip (data , candidates_lst )
71+ for candidate in candidates if candidate .lower () != query .lower ()]
72+ unique_taxonomic_pairs , seen = [], set ()
73+ for pair in taxonomic_pairs :
74+ key = (pair ["parent" ].lower (), pair ["child" ].lower ()) # Directional key (parent, child)
75+ if key not in seen :
76+ seen .add (key )
77+ unique_taxonomic_pairs .append (pair )
78+ return unique_taxonomic_pairs
7079 else :
7180 warnings .warn ("No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well." )
7281
@@ -86,11 +95,23 @@ def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
8695 candidates_lst = self ._retriever_predict (data = data ['types' ], top_k = self .top_k + 1 )
8796 taxonomic_pairs = []
8897 taxonomic_pairs_query = []
98+ seen = set ()
8999 for query , candidates in zip (data ['types' ], candidates_lst ):
90100 for candidate in candidates :
91101 if candidate != query :
92- taxonomic_pairs .append ((query , candidate ))
93- taxonomic_pairs_query .append (f"Head: { query } \n Tail: { candidate } " )
102+ # Directional pair 1: query -> candidate
103+ key1 = (query .lower (), candidate .lower ())
104+ if key1 not in seen :
105+ seen .add (key1 )
106+ taxonomic_pairs .append ((query , candidate ))
107+ taxonomic_pairs_query .append (f"Head: { query } \n Tail: { candidate } " )
108+ # Directional pair 2: candidate -> query
109+ key2 = (candidate .lower (), query .lower ())
110+ if key2 not in seen :
111+ seen .add (key2 )
112+ taxonomic_pairs .append ((candidate , query ))
113+ taxonomic_pairs_query .append (f"Head: { candidate } \n Tail: { query } " )
114+
94115 self ._retriever_fit (data = data ['relations' ])
95116 candidate_relations_lst = self ._retriever_predict (data = taxonomic_pairs_query , top_k = self .top_k )
96117 non_taxonomic_re = [{"head" : head , "tail" : tail , "relation" : relation }
0 commit comments