|
10 | 10 | import org.slf4j.Logger; |
11 | 11 | import org.slf4j.LoggerFactory; |
12 | 12 | import org.springframework.ai.document.Document; |
13 | | -import org.springframework.http.HttpHeaders; |
14 | 13 | import org.springframework.web.reactive.function.client.WebClient; |
| 14 | +import org.springframework.http.HttpHeaders; |
15 | 15 |
|
16 | 16 | /** |
17 | | - * A Reranker implementation that integrates with Cohere's Rerank API. |
18 | | - * This component reorders retrieved documents based on semantic relevance to the input query. |
| 17 | + * A Reranker implementation that integrates with Cohere's Rerank API. This component |
| 18 | + * reorders retrieved documents based on semantic relevance to the input query. |
19 | 19 | * |
20 | 20 | * @author KoreaNirsa |
21 | | - * @see <a href="https://docs.cohere.com/reference/rerank">Cohere Rerank API Documentation</a> |
| 21 | + * @see <a href="https://docs.cohere.com/reference/rerank">Cohere Rerank API |
| 22 | + * Documentation</a> |
22 | 23 | */ |
23 | 24 | public class CohereReranker { |
| 25 | + |
24 | 26 | private static final String COHERE_RERANK_ENDPOINT = "https://api.cohere.ai/v1/rerank"; |
25 | 27 |
|
26 | 28 | private static final Logger logger = LoggerFactory.getLogger(CohereReranker.class); |
27 | | - |
| 29 | + |
28 | 30 | private static final int MAX_DOCUMENTS = 1000; |
29 | 31 |
|
30 | 32 | private final WebClient webClient; |
31 | 33 |
|
32 | 34 | /** |
33 | 35 | * Constructs a CohereReranker that communicates with the Cohere Rerank API. |
34 | 36 | * Initializes the internal WebClient with the provided API key for authorization. |
35 | | - * |
36 | | - * @param cohereApi the API configuration object containing the required API key (must not be null) |
| 37 | + * @param cohereApi the API configuration object containing the required API key (must |
| 38 | + * not be null) |
37 | 39 | * @throws IllegalArgumentException if cohereApi is null |
38 | 40 | */ |
39 | | - CohereReranker(CohereApi cohereApi) { |
40 | | - if (cohereApi == null) { |
41 | | - throw new IllegalArgumentException("CohereApi must not be null"); |
42 | | - } |
43 | | - |
44 | | - this.webClient = WebClient.builder() |
45 | | - .baseUrl(COHERE_RERANK_ENDPOINT) |
46 | | - .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey()) |
47 | | - .build(); |
48 | | - } |
49 | | - |
50 | | - /** |
51 | | - * Reranks a list of documents based on the provided query using the Cohere API. |
52 | | - * |
53 | | - * @param query The user input query. |
54 | | - * @param documents The list of documents to rerank. |
55 | | - * @param topN The number of top results to return (at most). |
56 | | - * @return A reranked list of documents. If the API fails, returns the original list. |
57 | | - */ |
58 | | - public List<Document> rerank(String query, List<Document> documents, int topN) { |
59 | | - if (topN < 1) { |
60 | | - throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN); |
61 | | - } |
62 | | - |
63 | | - if (documents == null || documents.isEmpty()) { |
64 | | - logger.warn("Empty document list provided. Skipping rerank."); |
65 | | - return Collections.emptyList(); |
66 | | - } |
67 | | - |
68 | | - if (documents.size() > MAX_DOCUMENTS) { |
69 | | - logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", MAX_DOCUMENTS); |
70 | | - return documents; |
71 | | - } |
72 | | - |
73 | | - int adjustedTopN = Math.min(topN, documents.size()); |
74 | | - |
75 | | - Map<String, Object> payload = Map.of( |
76 | | - "query", query, |
77 | | - "documents", documents.stream().map(Document::getText).toList(), |
78 | | - "top_n", adjustedTopN |
79 | | - ); |
80 | | - |
81 | | - // Call the API and process the result |
82 | | - return sendRerankRequest(payload) |
83 | | - .map(results -> results.stream() |
84 | | - .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) |
85 | | - .map(r -> { |
86 | | - Document original = documents.get(r.getIndex()); |
87 | | - Map<String, Object> metadata = new HashMap<>(original.getMetadata()); |
88 | | - metadata.put("score", String.format("%.4f", r.getRelevanceScore())); |
89 | | - return new Document(original.getText(), metadata); |
90 | | - }) |
91 | | - .toList()) |
92 | | - .orElseGet(() -> { |
93 | | - logger.warn("Cohere response is null or invalid"); |
94 | | - return documents; |
95 | | - }); |
96 | | - } |
97 | | - |
98 | | - /** |
99 | | - * Sends a rerank request to the Cohere API and returns the result list. |
100 | | - * |
101 | | - * @param payload The request body including query, documents, and top_n. |
102 | | - * @return An Optional list of reranked results, or empty if failed. |
103 | | - */ |
104 | | - private Optional<List<RerankResponse.Result>> sendRerankRequest(Map<String, Object> payload) { |
105 | | - try { |
106 | | - RerankResponse response = webClient.post() |
107 | | - .bodyValue(payload) |
108 | | - .retrieve() |
109 | | - .bodyToMono(RerankResponse.class) |
110 | | - .block(); |
111 | | - |
112 | | - return Optional.ofNullable(response) |
113 | | - .map(RerankResponse::getResults); |
114 | | - } catch (Exception e) { |
115 | | - logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e); |
116 | | - return Optional.empty(); |
117 | | - } |
118 | | - } |
| 41 | + CohereReranker(CohereApi cohereApi) { |
| 42 | + if (cohereApi == null) { |
| 43 | + throw new IllegalArgumentException("CohereApi must not be null"); |
| 44 | + } |
| 45 | + |
| 46 | + this.webClient = WebClient.builder() |
| 47 | + .baseUrl(COHERE_RERANK_ENDPOINT) |
| 48 | + .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey()) |
| 49 | + .build(); |
| 50 | + } |
| 51 | + |
| 52 | + /** |
| 53 | + * Reranks a list of documents based on the provided query using the Cohere API. |
| 54 | + * @param query The user input query. |
| 55 | + * @param documents The list of documents to rerank. |
| 56 | + * @param topN The number of top results to return (at most). |
| 57 | + * @return A reranked list of documents. If the API fails, returns the original list. |
| 58 | + */ |
| 59 | + public List<Document> rerank(String query, List<Document> documents, int topN) { |
| 60 | + if (topN < 1) { |
| 61 | + throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN); |
| 62 | + } |
| 63 | + |
| 64 | + if (documents == null || documents.isEmpty()) { |
| 65 | + logger.warn("Empty document list provided. Skipping rerank."); |
| 66 | + return Collections.emptyList(); |
| 67 | + } |
| 68 | + |
| 69 | + if (documents.size() > MAX_DOCUMENTS) { |
| 70 | + logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", |
| 71 | + MAX_DOCUMENTS); |
| 72 | + return documents; |
| 73 | + } |
| 74 | + |
| 75 | + int adjustedTopN = Math.min(topN, documents.size()); |
| 76 | + |
| 77 | + Map<String, Object> payload = Map.of("query", query, "documents", |
| 78 | + documents.stream().map(Document::getText).toList(), "top_n", adjustedTopN); |
| 79 | + |
| 80 | + // Call the API and process the result |
| 81 | + return sendRerankRequest(payload).map(results -> results.stream() |
| 82 | + .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) |
| 83 | + .map(r -> { |
| 84 | + Document original = documents.get(r.getIndex()); |
| 85 | + Map<String, Object> metadata = new HashMap<>(original.getMetadata()); |
| 86 | + metadata.put("score", String.format("%.4f", r.getRelevanceScore())); |
| 87 | + return new Document(original.getText(), metadata); |
| 88 | + }) |
| 89 | + .toList()).orElseGet(() -> { |
| 90 | + logger.warn("Cohere response is null or invalid"); |
| 91 | + return documents; |
| 92 | + }); |
| 93 | + } |
| 94 | + |
| 95 | + /** |
| 96 | + * Sends a rerank request to the Cohere API and returns the result list. |
| 97 | + * @param payload The request body including query, documents, and top_n. |
| 98 | + * @return An Optional list of reranked results, or empty if failed. |
| 99 | + */ |
| 100 | + private Optional<List<RerankResponse.Result>> sendRerankRequest(Map<String, Object> payload) { |
| 101 | + try { |
| 102 | + RerankResponse response = webClient.post() |
| 103 | + .bodyValue(payload) |
| 104 | + .retrieve() |
| 105 | + .bodyToMono(RerankResponse.class) |
| 106 | + .block(); |
| 107 | + |
| 108 | + return Optional.ofNullable(response).map(RerankResponse::getResults); |
| 109 | + } |
| 110 | + catch (Exception e) { |
| 111 | + logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e); |
| 112 | + return Optional.empty(); |
| 113 | + } |
| 114 | + } |
| 115 | + |
119 | 116 | } |
0 commit comments