Skip to content

Commit 673e69f

Browse files
committed
Add a rerank abstract class and CohereRerank which uses Cohere Reranking
1 parent 7ec971f commit 673e69f

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

ontoaligner/base/rerank.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List, Dict, Any
3+
from tqdm import tqdm
4+
5+
class Rerank(ABC):
    """
    Abstract base class for reranking ontology alignment candidates.

    Subclasses implement `rerank_candidates`, which scores a list of candidate
    target texts against a single source text and returns them in ranked
    order. `rerank_retrieval_outputs` drives that method over a batch of
    retrieval outputs and flattens the results into source-target-score
    predictions.
    """

    def __init__(self, top_n: int = 1):
        """
        Args:
            top_n: Number of top candidates to keep per source after
                reranking. Must be at least 1. ``None`` is tolerated and
                disables the cut-off applied by `rerank_retrieval_outputs`.

        Raises:
            ValueError: If `top_n` is smaller than 1.
        """
        if top_n is not None and top_n < 1:
            raise ValueError(f"top_n must be >= 1, got {top_n}")
        self.top_n = top_n

    @abstractmethod
    def rerank_candidates(
        self, query: str, documents: List[str]
    ) -> List[Dict[str, Any]]:
        """
        Rerank a list of candidate documents against a query.

        Args:
            query: The source concept text.
            documents: List of candidate target concept texts.

        Returns:
            List of dicts, best match first, with keys:
                - "index": original index in the documents list
                - "relevance_score": reranking score
        """
        pass

    def rerank_retrieval_outputs(
        self,
        retrieval_outputs: List[Dict],
        source_iri2text: Dict[str, str],
        target_iri2text: Dict[str, str],
    ) -> List[Dict[str, Any]]:
        """
        Rerank all retrieval outputs and return flat source-target-score predictions.

        Entries whose source has no text, or whose candidates all lack text,
        contribute no predictions.

        Args:
            retrieval_outputs: Retrieval outputs; each entry is read for its
                "source" and "target-cands" keys.
            source_iri2text: Mapping from source IRI to text label.
            target_iri2text: Mapping from target IRI to text label.

        Returns:
            List of {"source": iri, "target": iri, "score": float}
        """
        predictions = []
        for entry in tqdm(retrieval_outputs, desc="Reranking"):
            predictions.extend(
                self._rerank_entry(entry, source_iri2text, target_iri2text)
            )
        return predictions

    def _rerank_entry(
        self,
        entry: Dict,
        source_iri2text: Dict[str, str],
        target_iri2text: Dict[str, str],
    ) -> List[Dict[str, Any]]:
        """
        Rerank the candidates of a single retrieval entry.

        Args:
            entry: One retrieval output with "source" and "target-cands" keys.
            source_iri2text: Mapping from source IRI to text label.
            target_iri2text: Mapping from target IRI to text label.

        Returns:
            At most `top_n` predictions of the form
            {"source": iri, "target": iri, "score": float}; empty when the
            source or every candidate has no text.
        """
        source_iri = entry["source"]
        candidate_iris = entry["target-cands"]

        query_text = source_iri2text.get(source_iri, "")
        if not query_text:
            return []

        # Drop candidates without a label instead of sending empty documents
        # to the reranker. Indices returned by `rerank_candidates` refer to
        # this filtered list, so the IRIs are kept alongside the texts.
        usable = [
            (iri, target_iri2text[iri])
            for iri in candidate_iris
            if target_iri2text.get(iri)
        ]
        if not usable:
            return []

        ranked = list(
            self.rerank_candidates(
                query=query_text,
                documents=[text for _, text in usable],
            )
        )

        # Apply the cut-off here as well, so subclasses that return every
        # candidate still honour `top_n`. `getattr` keeps subclasses that
        # never called `super().__init__()` working as before.
        limit = getattr(self, "top_n", None)
        if limit is not None:
            ranked = ranked[:limit]

        return [
            {
                "source": source_iri,
                "target": usable[item["index"]][0],
                "score": item["relevance_score"],
            }
            for item in ranked
        ]
77+
78+
class CohereRerank(Rerank):
    """
    Reranker using the Cohere Rerank API.

    See: https://docs.cohere.com/reference/rerank
    """

    def __init__(self, api_key: str, model: str = "rerank-v3.5", top_n: int = 1):
        """
        Args:
            api_key: Cohere API key.
            model: Cohere rerank model name.
            top_n: Number of top candidates to return per query.

        Raises:
            ImportError: If the `cohere` package is not installed.
        """
        super().__init__(top_n=top_n)
        # Imported lazily so `cohere` is only required when this reranker is
        # actually instantiated.
        try:
            import cohere
        except ImportError as err:
            raise ImportError(
                "CohereRerank requires the 'cohere' package; "
                "install it with `pip install cohere`."
            ) from err
        self.client = cohere.ClientV2(api_key)
        self.model = model

    def rerank_candidates(
        self, query: str, documents: List[str]
    ) -> List[Dict[str, Any]]:
        """
        Calls the Cohere Rerank API to rerank candidate documents.

        Args:
            query: The source concept text.
            documents: List of candidate target concept texts.

        Returns:
            List of dicts with "index" and "relevance_score",
            limited to self.top_n results. Empty when `documents` is empty.
        """
        # Nothing to rank: skip the network round-trip entirely.
        if not documents:
            return []

        response = self.client.rerank(
            model=self.model,
            query=query,
            documents=documents,
            top_n=self.top_n,
        )
        return [
            {"index": r.index, "relevance_score": r.relevance_score}
            for r in response.results
        ]

0 commit comments

Comments
 (0)