From 9b51b78e58a136fd8ce40176f65b6b5fc97e0ffc Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Sun, 18 May 2025 17:56:09 +0200 Subject: [PATCH] RAG - Document joiner should sort by score The current implementation of ConcatenationDocumentJoiner does not order its result. It should sort the final document list by score in descending order, so as to keep the most relevant documents at the front of the list. This PR fixes that. Signed-off-by: Thomas Vitale --- .../join/ConcatenationDocumentJoiner.java | 12 +++++++--- .../ConcatenationDocumentJoinerTests.java | 24 ++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java index 56038587fa6..2f9032c9a9a 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.ai.rag.retrieval.join; import java.util.ArrayList; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.function.Function; @@ -32,7 +33,8 @@ /** * Combines documents retrieved based on multiple queries and from multiple data sources * by concatenating them into a single collection of documents. In case of duplicate - * documents, the first occurrence is kept. The score of each document is kept as is. + * documents, the first occurrence is kept. The score of each document is kept as is. 
The + * result is a list of unique documents sorted by their score in descending order. * * @author Thomas Vitale * @since 1.0.0 @@ -54,7 +56,11 @@ public List<Document> join(Map<Query, List<List<Document>>> documentsForQuery) { .flatMap(List::stream) .flatMap(List::stream) .collect(Collectors.toMap(Document::getId, Function.identity(), (existing, duplicate) -> existing)) - .values()); + .values() + .stream() + .sorted(Comparator.comparingDouble((Document doc) -> doc.getScore() != null ? doc.getScore() : 0.0) + .reversed()) + .toList()); } } diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java index 39a588555e6..b7035b1e39a 100644 --- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java +++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java @@ -21,7 +21,6 @@ import java.util.Map; import org.junit.jupiter.api.Test; - import org.springframework.ai.document.Document; import org.springframework.ai.rag.Query; @@ -92,4 +91,27 @@ void whenDuplicatedDocumentsThenOnlyFirstOccurrenceIsKept() { assertThat(result).extracting(Document::getText).containsOnlyOnce("Content 2"); } + @Test + void shouldSortDocumentsByDescendingScore() { + //@formatter:off + DocumentJoiner documentJoiner = new ConcatenationDocumentJoiner(); + var documentsForQuery = new HashMap<Query, List<List<Document>>>(); + documentsForQuery.put(new Query("query1"), List.of( + List.of( + Document.builder().id("1").text("Content 1").score(0.81).build(), + Document.builder().id("2").text("Content 2").score(0.83).build()), + List.of( + Document.builder().id("3").text("Content 3").score(null).build()))); + documentsForQuery.put(new Query("query2"), List.of( + List.of( + Document.builder().id("4").text("Content 4").score(0.85).build(), + Document.builder().id("5").text("Content 
5").score(0.77).build()))); + + List<Document> result = documentJoiner.join(documentsForQuery); + + assertThat(result).hasSize(5); + assertThat(result).extracting(Document::getId).containsExactly("4", "2", "1", "5", "3"); + //@formatter:on + } + }