Skip to content

Commit 95675a8

Browse files
dafriz authored and ilayaperumalg committed
Set ElasticSearch size to match requested topK used in KNN search
1 parent cd5684a commit 95675a8

File tree

2 files changed

+49
-10
lines changed

2 files changed

+49
-10
lines changed

vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -242,16 +242,15 @@ public List<Document> doSimilaritySearch(SearchRequest searchRequest) {
242242
final float finalThreshold = threshold;
243243
float[] vectors = this.embeddingModel.embed(searchRequest.getQuery());
244244

245-
SearchResponse<Document> res = this.elasticsearchClient.search(
246-
sr -> sr.index(this.options.getIndexName())
247-
.knn(knn -> knn.queryVector(EmbeddingUtils.toList(vectors))
248-
.similarity(finalThreshold)
249-
.k((long) searchRequest.getTopK())
250-
.field("embedding")
251-
.numCandidates((long) (1.5 * searchRequest.getTopK()))
252-
.filter(fl -> fl.queryString(
253-
qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))),
254-
Document.class);
245+
SearchResponse<Document> res = this.elasticsearchClient.search(sr -> sr.index(this.options.getIndexName())
246+
.knn(knn -> knn.queryVector(EmbeddingUtils.toList(vectors))
247+
.similarity(finalThreshold)
248+
.k((long) searchRequest.getTopK())
249+
.field("embedding")
250+
.numCandidates((long) (1.5 * searchRequest.getTopK()))
251+
.filter(fl -> fl
252+
.queryString(qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression())))))
253+
.size(searchRequest.getTopK()), Document.class);
255254

256255
return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList());
257256
}

vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import java.nio.charset.StandardCharsets;
2121
import java.time.Duration;
2222
import java.time.ZonedDateTime;
23+
import java.util.ArrayList;
2324
import java.util.Date;
2425
import java.util.List;
2526
import java.util.Map;
@@ -400,6 +401,45 @@ public void searchThresholdTest(String similarityFunction) {
400401
});
401402
}
402403

404+
@Test
405+
public void overDefaultSizeTest() {
406+
407+
var overDefaultSize = 12;
408+
409+
getContextRunner().run(context -> {
410+
411+
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine",
412+
ElasticsearchVectorStore.class);
413+
414+
var testDocs = new ArrayList<Document>();
415+
for (int i = 0; i < overDefaultSize; i++) {
416+
testDocs.add(new Document(String.valueOf(i), "Great Depression " + i, Map.of()));
417+
}
418+
vectorStore.add(testDocs);
419+
420+
Awaitility.await()
421+
.until(() -> vectorStore.similaritySearch(
422+
SearchRequest.builder().query("Great Depression").topK(1).similarityThresholdAll().build()),
423+
hasSize(1));
424+
425+
List<Document> results = vectorStore.similaritySearch(SearchRequest.builder()
426+
.query("Great Depression")
427+
.topK(overDefaultSize)
428+
.similarityThresholdAll()
429+
.build());
430+
431+
assertThat(results).hasSize(overDefaultSize);
432+
433+
// Remove all documents from the store
434+
vectorStore.delete(testDocs.stream().map(Document::getId).toList());
435+
436+
Awaitility.await()
437+
.until(() -> vectorStore.similaritySearch(
438+
SearchRequest.builder().query("Great Depression").topK(1).similarityThresholdAll().build()),
439+
hasSize(0));
440+
});
441+
}
442+
403443
@SpringBootConfiguration
404444
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
405445
public static class TestApplication {

0 commit comments

Comments
 (0)