Skip to content

Commit 112813f

Browse files
committed
refact: sort by descending score In ConcatenationDocumentJoiner
Signed-off-by: ghdcksgml1 <[email protected]>
1 parent e4357ba commit 112813f

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoiner.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.rag.retrieval.join;
1818

1919
import java.util.ArrayList;
20+
import java.util.Comparator;
2021
import java.util.List;
2122
import java.util.Map;
2223
import java.util.function.Function;
@@ -34,7 +35,7 @@
3435
* by concatenating them into a single collection of documents. In case of duplicate
3536
* documents, the first occurrence is kept. The score of each document is kept as is.
3637
*
37-
* @author Thomas Vitale
38+
* @author Thomas Vitale, ghdcksgml1
3839
* @since 1.0.0
3940
*/
4041
public class ConcatenationDocumentJoiner implements DocumentJoiner {
@@ -54,7 +55,10 @@ public List<Document> join(Map<Query, List<List<Document>>> documentsForQuery) {
5455
.flatMap(List::stream)
5556
.flatMap(List::stream)
5657
.collect(Collectors.toMap(Document::getId, Function.identity(), (existing, duplicate) -> existing))
57-
.values());
58+
.values()
59+
.stream()
60+
.sorted(Comparator.comparing((Document d1) -> (d1.getScore() != null) ? d1.getScore() : 0.0).reversed())
61+
.toList());
5862
}
5963

6064
}

spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
/**
3232
* Unit tests for {@link ConcatenationDocumentJoiner}.
3333
*
34-
* @author Thomas Vitale
34+
* @author Thomas Vitale, ghdcksgml1
3535
*/
3636
class ConcatenationDocumentJoinerTests {
3737

@@ -92,4 +92,22 @@ void whenDuplicatedDocumentsThenOnlyFirstOccurrenceIsKept() {
9292
assertThat(result).extracting(Document::getText).containsOnlyOnce("Content 2");
9393
}
9494

95+
@Test
96+
void whenSeveralQueryExistsInMapThenDocumentsAreJoinedInDescendingScoreOrder() {
97+
DocumentJoiner documentJoiner = new ConcatenationDocumentJoiner();
98+
var documentsForQuery = new HashMap<Query, List<List<Document>>>();
99+
documentsForQuery.put(new Query("query1"),
100+
List.of(List.of(Document.builder().id("1").text("Content 1").score(0.9).build(),
101+
Document.builder().id("4").text("Content 4").score(0.6).build()),
102+
List.of(Document.builder().id("2").text("Content 2").score(0.8).build())));
103+
documentsForQuery.put(new Query("query2"),
104+
List.of(List.of(Document.builder().id("3").text("Content 3").score(0.7).build())));
105+
106+
List<Document> result = documentJoiner.join(documentsForQuery);
107+
108+
assertThat(result).hasSize(4);
109+
assertThat(result).extracting(Document::getId).containsExactly("1", "2", "3", "4");
110+
assertThat(result).extracting(Document::getScore).containsExactly(0.9, 0.8, 0.7, 0.6);
111+
}
112+
95113
}

0 commit comments

Comments
 (0)