1+ from abc import ABC , abstractmethod
2+ from typing import List , Dict , Any
3+ from tqdm import tqdm
4+
5+ class Rerank (ABC ):
6+ """
7+ Abstract base class for reranking ontology alignment candidates.
8+
9+ Subclasses must implement `rerank_candidates` which, given a source text
10+ and a list of candidate target texts, returns them reranked with scores.
11+ """
12+
13+ def __init__ (self , top_n : int = 1 ):
14+ """
15+ Args:
16+ top_n: Number of top candidates to return per source after reranking.
17+ """
18+ self .top_n = top_n
19+
20+ @abstractmethod
21+ def rerank_candidates (
22+ self , query : str , documents : List [str ]
23+ ) -> List [Dict [str , Any ]]:
24+ """
25+ Rerank a list of candidate documents against a query.
26+
27+ Args:
28+ query: The source concept text.
29+ documents: List of candidate target concept texts.
30+
31+ Returns:
32+ List of dicts with keys:
33+ - "index": original index in the documents list
34+ - "relevance_score": reranking score
35+ """
36+ pass
37+
38+ def rerank_retrieval_outputs (
39+ self ,
40+ retrieval_outputs : List [Dict ],
41+ source_iri2text : Dict [str , str ],
42+ target_iri2text : Dict [str , str ],
43+ ) -> List [Dict [str , Any ]]:
44+ """
45+ Rerank all retrieval outputs and return flat source-target-score predictions.
46+
47+ Args:
48+ retrieval_outputs: BM25 outputs with "source", "target-cands", "score-cands".
49+ source_iri2text: Mapping from source IRI to text label.
50+ target_iri2text: Mapping from target IRI to text label.
51+
52+ Returns:
53+ List of {"source": iri, "target": iri, "score": float}
54+ """
55+ predictions = []
56+ for entry in tqdm (retrieval_outputs , desc = "Reranking" ):
57+ source_iri = entry ["source" ]
58+ candidate_iris = entry ["target-cands" ]
59+
60+ query_text = source_iri2text .get (source_iri , "" )
61+ doc_texts = [target_iri2text .get (iri , "" ) for iri in candidate_iris ]
62+
63+ if not query_text or not doc_texts :
64+ continue
65+
66+ ranked = self .rerank_candidates (query = query_text , documents = doc_texts )
67+
68+ for item in ranked :
69+ idx = item ["index" ]
70+ predictions .append ({
71+ "source" : source_iri ,
72+ "target" : candidate_iris [idx ],
73+ "score" : item ["relevance_score" ],
74+ })
75+
76+ return predictions
77+
78+ class CohereRerank (Rerank ):
79+ """
80+ Reranker using the Cohere Rerank API.
81+
82+ See: https://docs.cohere.com/reference/rerank
83+ """
84+
85+ def __init__ (self , api_key : str , model : str = "rerank-v3.5" , top_n : int = 1 ):
86+ """
87+ Args:
88+ api_key: Cohere API key.
89+ model: Cohere rerank model name.
90+ top_n: Number of top candidates to return per query.
91+ """
92+ super ().__init__ (top_n = top_n )
93+ import cohere
94+ self .client = cohere .ClientV2 (api_key )
95+ self .model = model
96+
97+ def rerank_candidates (
98+ self , query : str , documents : List [str ]
99+ ) -> List [Dict [str , Any ]]:
100+ """
101+ Calls the Cohere Rerank API to rerank candidate documents.
102+
103+ Args:
104+ query: The source concept text.
105+ documents: List of candidate target concept texts.
106+
107+ Returns:
108+ List of dicts with "index" and "relevance_score",
109+ limited to self.top_n results.
110+ """
111+ response = self .client .rerank (
112+ model = self .model ,
113+ query = query ,
114+ documents = documents ,
115+ top_n = self .top_n ,
116+ )
117+ return [
118+ {"index" : r .index , "relevance_score" : r .relevance_score }
119+ for r in response .results
120+ ]
0 commit comments