diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java index d0a2c07eb41..b80b26d0bc1 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java @@ -271,16 +271,15 @@ public List doSimilaritySearch(SearchRequest searchRequest) { final float finalThreshold = threshold; float[] vectors = this.embeddingModel.embed(searchRequest.getQuery()); - SearchResponse res = this.elasticsearchClient.search( - sr -> sr.index(this.options.getIndexName()) - .knn(knn -> knn.queryVector(EmbeddingUtils.toList(vectors)) - .similarity(finalThreshold) - .k((long) searchRequest.getTopK()) - .field("embedding") - .numCandidates((long) (1.5 * searchRequest.getTopK())) - .filter(fl -> fl.queryString( - qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))), - Document.class); + SearchResponse res = this.elasticsearchClient.search(sr -> sr.index(this.options.getIndexName()) + .knn(knn -> knn.queryVector(EmbeddingUtils.toList(vectors)) + .similarity(finalThreshold) + .k((long) searchRequest.getTopK()) + .field("embedding") + .numCandidates((long) (1.5 * searchRequest.getTopK())) + .filter(fl -> fl + .queryString(qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))) + .size(searchRequest.getTopK()), Document.class); return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList()); } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java index 04ebc0bbd93..c2f10c353e1 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java @@ -20,10 +20,7 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.ZonedDateTime; -import java.util.Date; -import java.util.List; -import java.util.Map; -import java.util.UUID; +import java.util.*; import java.util.concurrent.TimeUnit; import co.elastic.clients.elasticsearch.ElasticsearchClient; @@ -400,6 +397,45 @@ public void searchThresholdTest(String similarityFunction) { }); } + @Test + public void overDefaultSizeTest() { + + var overDefaultSize = 12; + + getContextRunner().run(context -> { + + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", + ElasticsearchVectorStore.class); + + var testDocs = new ArrayList(); + for (int i = 0; i < overDefaultSize; i++) { + testDocs.add(new Document(String.valueOf(i), "Great Depression " + i, Map.of())); + } + vectorStore.add(testDocs); + + Awaitility.await() + .until(() -> vectorStore.similaritySearch( + SearchRequest.builder().query("Great Depression").topK(1).similarityThresholdAll().build()), + hasSize(1)); + + List results = vectorStore.similaritySearch(SearchRequest.builder() + .query("Great Depression") + .topK(overDefaultSize) + .similarityThresholdAll() + .build()); + + assertThat(results).hasSize(overDefaultSize); + + // Remove all documents from the store + vectorStore.delete(testDocs.stream().map(Document::getId).toList()); + + Awaitility.await() + .until(() -> vectorStore.similaritySearch( + SearchRequest.builder().query("Great Depression").topK(1).similarityThresholdAll().build()), + hasSize(0)); + }); + } + @SpringBootConfiguration @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication {