From 86650ac78559a5eeb7be517e2aead7bcda9e74a7 Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Fri, 12 Apr 2024 17:42:13 +0200 Subject: [PATCH 01/10] knn instead of script_score, removed initialization --- .../vectorstore/ElasticsearchVectorStore.java | 69 ++++---------- .../ElasticsearchVectorStoreIT.java | 92 ++++++++++--------- 2 files changed, 68 insertions(+), 93 deletions(-) diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 8f672862cd5..634be621adf 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -16,17 +16,13 @@ package org.springframework.ai.vectorstore; import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty; -import co.elastic.clients.elasticsearch._types.mapping.Property; -import co.elastic.clients.elasticsearch._types.query_dsl.Query; +import co.elastic.clients.elasticsearch._types.mapping.TypeMapping; import co.elastic.clients.elasticsearch.core.BulkRequest; import co.elastic.clients.elasticsearch.core.BulkResponse; -import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; +import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.core.search.Hit; import co.elastic.clients.elasticsearch.indices.CreateIndexResponse; -import co.elastic.clients.json.JsonData; import co.elastic.clients.json.jackson.JacksonJsonpMapper; -import co.elastic.clients.transport.endpoints.BooleanResponse; import co.elastic.clients.transport.rest_client.RestClientTransport; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; @@ -53,9 +49,6 @@ */ public class ElasticsearchVectorStore implements VectorStore, InitializingBean { - // divided by 2 to get score in the range [0, 1] - public static final String COSINE_SIMILARITY_FUNCTION = "(cosineSimilarity(params.query_vector, 'embedding') + 1.0) / 2"; - private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class); private final EmbeddingModel embeddingModel; @@ -84,14 +77,6 @@ public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestCli this.embeddingModel = embeddingModel; this.options = options; this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter(); - // the potential functions for vector fields at - // https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-script-score-query.html#vector-functions - this.similarityFunction = COSINE_SIMILARITY_FUNCTION; - } - - public ElasticsearchVectorStore withSimilarityFunction(String similarityFunction) { - this.similarityFunction = similarityFunction; - return this; } @Override @@ -144,23 +129,23 @@ public List similaritySearch(SearchRequest searchRequest) { searchRequest.getFilterExpression()); } - public List similaritySearch(List embedding, int topK, double similarityThreshold, - Filter.Expression filterExpression) { - return similaritySearch( - new co.elastic.clients.elasticsearch.core.SearchRequest.Builder().index(options.getIndexName()) - .query(getElasticsearchSimilarityQuery(embedding, filterExpression)) - .size(topK) - .minScore(similarityThreshold) - .build()); - } + SearchResponse res = elasticsearchClient.search( + sr -> sr.index(this.index) + .minScore(searchRequest.getSimilarityThreshold()) + .knn(knn -> knn.queryVector(vectors) + .k(searchRequest.getTopK()) + .field(EMBEDDING_FIELD) + .numCandidates((long) (1.5 * searchRequest.getTopK())) + .filter(fl -> fl.queryString( + qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))), + Document.class); - private Query getElasticsearchSimilarityQuery(List embedding, Filter.Expression filterExpression) { - return Query.of(queryBuilder -> queryBuilder.scriptScore(scriptScoreQueryBuilder -> scriptScoreQueryBuilder - .query(queryBuilder2 -> queryBuilder2.queryString(queryStringQuerybuilder -> queryStringQuerybuilder - .query(getElasticsearchQueryString(filterExpression)))) - .script(scriptBuilder -> scriptBuilder - .inline(inlineScriptBuilder -> inlineScriptBuilder.source(this.similarityFunction) - .params("query_vector", JsonData.of(embedding)))))); + return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList()); + + } + catch (IOException e) { + throw new RuntimeException(e); + } } private String getElasticsearchQueryString(Filter.Expression filterExpression) { @@ -169,20 +154,6 @@ private String getElasticsearchQueryString(Filter.Expression filterExpression) { } - private List similaritySearch(co.elastic.clients.elasticsearch.core.SearchRequest searchRequest) { - try { - return this.elasticsearchClient.search(searchRequest, Document.class) - .hits() - .hits() - .stream() - .map(this::toDocument) - .collect(Collectors.toList()); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - private Document toDocument(Hit hit) { Document document = hit.source(); document.getMetadata().put("distance", 1 - hit.score().floatValue()); @@ -191,9 +162,7 @@ private Document toDocument(Hit hit) { private boolean indexExists() { try { - BooleanResponse response = this.elasticsearchClient.indices() - .exists(existRequestBuilder -> existRequestBuilder.index(options.getIndexName())); - return response.value(); + return this.elasticsearchClient.indices().exists(ex -> ex.index(this.index)).value(); } catch (IOException e) { throw new RuntimeException(e); diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index 350c121c431..0c85877a8cb 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -25,6 +25,12 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.mapping.TypeMapping; +import co.elastic.clients.json.jackson.JacksonJsonpMapper; +import co.elastic.clients.transport.rest_client.RestClientTransport; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.http.HttpHost; import org.awaitility.Awaitility; import org.elasticsearch.client.RestClient; @@ -36,7 +42,6 @@ import org.testcontainers.elasticsearch.ElasticsearchContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -62,11 +67,7 @@ class ElasticsearchVectorStoreIT { "docker.elastic.co/elasticsearch/elasticsearch:8.12.2") .withEnv("xpack.security.enabled", "false"); - private static final String DEFAULT = "default cosine similarity"; - - protected final ObjectMapper objectMapper = new ObjectMapper(); - - private List documents = List.of( + private final List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); @@ -92,28 +93,41 @@ private ApplicationContextRunner getContextRunner() { return new ApplicationContextRunner().withUserConfiguration(TestApplication.class); } + private void prepareMapping(String similarityFunction, ElasticsearchVectorStore vectorStore) { + if (!similarityFunction.equals("cosine")) { // cosine is the default similarity + // function, no need for custom + // mapping + + // vector dimension 1536 is openAI specific + TypeMapping mapping = TypeMapping.of(tm -> tm.properties("embedding", + p -> p.denseVector(dv -> dv.dims(1536).index(true).similarity(similarityFunction)))); + + vectorStore.createIndexMapping("spring-ai-document-index", mapping); + } + } + @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); vectorStore.delete(List.of("_all")); + // deleting index so that it can be recreated with new mapping containing a + // different similarity function + ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class); + if (elasticsearchClient.indices().exists(ex -> ex.index("spring-ai-document-index")).value()) { + elasticsearchClient.indices().delete(del -> del.index("spring-ai-document-index")); + } }); } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "max_inner_product" }) public void addAndSearchTest(String similarityFunction) { getContextRunner().run(context -> { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } + prepareMapping(similarityFunction, vectorStore); vectorStore.add(documents); @@ -144,19 +158,13 @@ public void addAndSearchTest(String similarityFunction) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "max_inner_product" }) public void searchWithFilters(String similarityFunction) { getContextRunner().run(context -> { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } + prepareMapping(similarityFunction, vectorStore); var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner", Map.of("country", "BG", "year", 2020, "activationDate", new Date(1000))); @@ -246,18 +254,13 @@ public void searchWithFilters(String similarityFunction) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "max_inner_product" }) public void documentUpdateTest(String similarityFunction) { getContextRunner().run(context -> { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } + + prepareMapping(similarityFunction, vectorStore); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", Map.of("meta1", "meta1")); @@ -306,18 +309,12 @@ public void documentUpdateTest(String similarityFunction) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "max_inner_product" }) public void searchThresholdTest(String similarityFunction) { - getContextRunner().run(context -> { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } + + prepareMapping(similarityFunction, vectorStore); vectorStore.add(documents); @@ -360,10 +357,8 @@ public void searchThresholdTest(String similarityFunction) { public static class TestApplication { @Bean - public ElasticsearchVectorStore vectorStore(EmbeddingModel embeddingModel) { - return new ElasticsearchVectorStore( - RestClient.builder(HttpHost.create(elasticsearchContainer.getHttpHostAddress())).build(), - embeddingModel, true); + public ElasticsearchVectorStore vectorStore(EmbeddingClient embeddingClient, RestClient restClient) { + return new ElasticsearchVectorStore(restClient, embeddingClient); } @Bean @@ -371,6 +366,17 @@ public EmbeddingModel embeddingModel() { return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); } + @Bean + RestClient restClient() { + return RestClient.builder(HttpHost.create(elasticsearchContainer.getHttpHostAddress())).build(); + } + + @Bean + ElasticsearchClient elasticsearchClient(RestClient restClient) { + return new ElasticsearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper( + new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)))); + } + } } From 8180189438a4e270785c9c548aa323e7dfe1ddde Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Tue, 16 Apr 2024 12:37:51 +0200 Subject: [PATCH 02/10] only using normalized similarities, adjusted unit test --- .../spring-ai-elasticsearch-store/pom.xml | 1 - .../vectorstore/ElasticsearchVectorStore.java | 27 ++++++- .../ElasticsearchVectorStoreIT.java | 80 ++++++++++++------- 3 files changed, 73 insertions(+), 35 deletions(-) diff --git a/vector-stores/spring-ai-elasticsearch-store/pom.xml b/vector-stores/spring-ai-elasticsearch-store/pom.xml index 67ac31ac7db..717eca2c6cb 100644 --- a/vector-stores/spring-ai-elasticsearch-store/pom.xml +++ b/vector-stores/spring-ai-elasticsearch-store/pom.xml @@ -45,7 +45,6 @@ test - org.springframework.ai spring-ai-test diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 634be621adf..c76662637d0 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -21,6 +21,7 @@ import co.elastic.clients.elasticsearch.core.BulkResponse; import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.core.search.Hit; +import co.elastic.clients.elasticsearch.indices.CreateIndexRequest; import co.elastic.clients.elasticsearch.indices.CreateIndexResponse; import co.elastic.clients.json.jackson.JacksonJsonpMapper; import co.elastic.clients.transport.rest_client.RestClientTransport; @@ -42,6 +43,8 @@ import java.util.Optional; import java.util.stream.Collectors; +import static java.lang.Math.sqrt; + /** * @author Jemin Huh * @author Wei Jiang @@ -57,6 +60,8 @@ public class ElasticsearchVectorStore implements VectorStore, InitializingBean { private final ElasticsearchVectorStoreOptions options; + private String similarityFunction = SIMILARITY_DEFAULT; + private final FilterExpressionConverter filterExpressionConverter; private String similarityFunction; @@ -131,10 +136,10 @@ public List similaritySearch(SearchRequest searchRequest) { SearchResponse res = elasticsearchClient.search( sr -> sr.index(this.index) - .minScore(searchRequest.getSimilarityThreshold()) .knn(knn -> knn.queryVector(vectors) + .similarity((float) searchRequest.getSimilarityThreshold()) .k(searchRequest.getTopK()) - .field(EMBEDDING_FIELD) + .field(EMBEDDING) .numCandidates((long) (1.5 * searchRequest.getTopK())) .filter(fl -> fl.queryString( qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))), @@ -156,11 +161,25 @@ private String getElasticsearchQueryString(Filter.Expression filterExpression) { private Document toDocument(Hit hit) { Document document = hit.source(); - document.getMetadata().put("distance", 1 - hit.score().floatValue()); + document.getMetadata().put("distance", calculateDistance(hit.score().floatValue())); return document; } - private boolean indexExists() { + // more info on score/distance calculation + // https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search + private float calculateDistance(Float score) { + switch (similarityFunction) { + case "l2_norm": + // the returned value of l2_norm is the opposite of the other functions + // (closest to zero means more accurate) + return (float) (sqrt((1 / score) - 1)); + // cosine and dot_product + default: + return (2 * score) - 1; + } + } + + public boolean existsIndex() { try { return this.elasticsearchClient.indices().exists(ex -> ex.index(this.index)).value(); } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index 0c85877a8cb..904c78ba5fa 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -102,17 +102,25 @@ private void prepareMapping(String similarityFunction, ElasticsearchVectorStore TypeMapping mapping = TypeMapping.of(tm -> tm.properties("embedding", p -> p.denseVector(dv -> dv.dims(1536).index(true).similarity(similarityFunction)))); - vectorStore.createIndexMapping("spring-ai-document-index", mapping); + vectorStore.createIndexMapping(mapping); } } + private double getThreshold(String similarity) { + // l2_norm works in reverse: accept all threshold = 1, accept none = 0 + if (similarity.equals("l2_norm")) { + return 1.0; + } + return 0.0; + } + @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); vectorStore.delete(List.of("_all")); - // deleting index so that it can be recreated with new mapping containing a - // different similarity function + // deleting index so that it can be recreated with new mapping + // containing a different similarity function ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class); if (elasticsearchClient.indices().exists(ex -> ex.index("spring-ai-document-index")).value()) { elasticsearchClient.indices().delete(del -> del.index("spring-ai-document-index")); @@ -121,7 +129,7 @@ void cleanDatabase() { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "max_inner_product" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void addAndSearchTest(String similarityFunction) { getContextRunner().run(context -> { @@ -131,13 +139,15 @@ public void addAndSearchTest(String similarityFunction) { vectorStore.add(documents); + double threshold = getThreshold(similarityFunction); + Awaitility.await() - .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), + .until(() -> vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(threshold)), hasSize(1)); - List results = vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(threshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -151,14 +161,14 @@ public void addAndSearchTest(String similarityFunction) { vectorStore.delete(documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), + .until(() -> vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(threshold)), hasSize(0)); }); } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "max_inner_product" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void searchWithFilters(String similarityFunction) { getContextRunner().run(context -> { @@ -175,12 +185,16 @@ public void searchWithFilters(String similarityFunction) { vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); + double threshold = getThreshold(similarityFunction); + Awaitility.await() - .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)), hasSize(3)); + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("The World").withTopK(5).withSimilarityThreshold(threshold)), + hasSize(3)); List results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThresholdAll() + .withSimilarityThreshold(threshold) .withFilterExpression("country == 'NL'")); assertThat(results).hasSize(1); @@ -188,7 +202,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThresholdAll() + .withSimilarityThreshold(threshold) .withFilterExpression("country == 'BG'")); assertThat(results).hasSize(2); @@ -197,7 +211,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThresholdAll() + .withSimilarityThreshold(threshold) .withFilterExpression("country == 'BG' && year == 2020")); assertThat(results).hasSize(1); @@ -205,7 +219,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThresholdAll() + .withSimilarityThreshold(threshold) .withFilterExpression("country in ['BG']")); assertThat(results).hasSize(2); @@ -214,14 +228,14 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThresholdAll() + .withSimilarityThreshold(threshold) .withFilterExpression("country in ['BG','NL']")); assertThat(results).hasSize(3); results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThresholdAll() + .withSimilarityThreshold(threshold) .withFilterExpression("country not in ['BG']")); assertThat(results).hasSize(1); @@ -229,7 +243,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThresholdAll() + .withSimilarityThreshold(threshold) .withFilterExpression("NOT(country not in ['BG'])")); assertThat(results).hasSize(2); @@ -238,7 +252,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThresholdAll() + .withSimilarityThreshold(threshold) .withFilterExpression( "activationDate > " + ZonedDateTime.parse("1970-01-01T00:00:02Z").toInstant().toEpochMilli())); @@ -254,7 +268,7 @@ public void searchWithFilters(String similarityFunction) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "max_inner_product" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void documentUpdateTest(String similarityFunction) { getContextRunner().run(context -> { @@ -266,13 +280,15 @@ public void documentUpdateTest(String similarityFunction) { Map.of("meta1", "meta1")); vectorStore.add(List.of(document)); + double threshold = getThreshold(similarityFunction); + Awaitility.await() .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(0).withTopK(5)), + .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(threshold).withTopK(5)), hasSize(1)); List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(0).withTopK(5)); + .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(threshold).withTopK(5)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -285,7 +301,9 @@ public void documentUpdateTest(String similarityFunction) { "The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2")); vectorStore.add(List.of(sameIdDocument)); - SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5); + SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar") + .withTopK(5) + .withSimilarityThreshold(threshold); Awaitility.await() .until(() -> vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(), @@ -309,7 +327,7 @@ public void documentUpdateTest(String similarityFunction) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "max_inner_product" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void searchThresholdTest(String similarityFunction) { getContextRunner().run(context -> { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); @@ -318,9 +336,11 @@ public void searchThresholdTest(String similarityFunction) { vectorStore.add(documents); + double threshold = getThreshold(similarityFunction); + SearchRequest query = SearchRequest.query("Great Depression") .withTopK(50) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL); + .withSimilarityThreshold(threshold); Awaitility.await().until(() -> vectorStore.similaritySearch(query), hasSize(3)); @@ -330,10 +350,10 @@ public void searchThresholdTest(String similarityFunction) { assertThat(distances).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + float thresholdResult = (distances.get(0) + distances.get(1)) / 2; List results = vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(1 - threshold)); + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(thresholdResult)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -346,8 +366,8 @@ public void searchThresholdTest(String similarityFunction) { vectorStore.delete(documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(0)), + .until(() -> vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(threshold)), hasSize(0)); }); } From a73509b908b975d0e917798e766cc9d3bd99d287 Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Tue, 16 Apr 2024 12:44:19 +0200 Subject: [PATCH 03/10] import clean --- .../springframework/ai/vectorstore/ElasticsearchVectorStore.java | 1 - 1 file changed, 1 deletion(-) diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index c76662637d0..b6c0ddedffc 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -21,7 +21,6 @@ import co.elastic.clients.elasticsearch.core.BulkResponse; import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.core.search.Hit; -import co.elastic.clients.elasticsearch.indices.CreateIndexRequest; import co.elastic.clients.elasticsearch.indices.CreateIndexResponse; import co.elastic.clients.json.jackson.JacksonJsonpMapper; import co.elastic.clients.transport.rest_client.RestClientTransport; From 0ab99810c37abe9bad4f170f752005e0122cc608 Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Tue, 16 Apr 2024 14:50:59 +0200 Subject: [PATCH 04/10] making l2norm's distances consistent with others --- .../vectorstore/ElasticsearchVectorStore.java | 23 ++++-- .../ElasticsearchVectorStoreIT.java | 76 ++++++++----------- 2 files changed, 47 insertions(+), 52 deletions(-) diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index b6c0ddedffc..4df12a91e07 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -128,15 +128,22 @@ private BulkResponse bulkRequest(BulkRequest bulkRequest) { @Override public List similaritySearch(SearchRequest searchRequest) { Assert.notNull(searchRequest, "The search request must not be null."); - return similaritySearch(this.embeddingModel.embed(searchRequest.getQuery()), searchRequest.getTopK(), - Double.valueOf(searchRequest.getSimilarityThreshold()).floatValue(), - searchRequest.getFilterExpression()); - } + try { + float threshold = (float) searchRequest.getSimilarityThreshold(); + // reverting l2_norm distance to its original value + if (similarityFunction.equals("l2_norm")) { + threshold = 1 - threshold; + } + final float finalThreshold = threshold; + List vectors = this.embeddingClient.embed(searchRequest.getQuery()) + .stream() + .map(Double::floatValue) + .toList(); SearchResponse res = elasticsearchClient.search( sr -> sr.index(this.index) .knn(knn -> knn.queryVector(vectors) - .similarity((float) searchRequest.getSimilarityThreshold()) + .similarity(finalThreshold) .k(searchRequest.getTopK()) .field(EMBEDDING) .numCandidates((long) (1.5 * searchRequest.getTopK())) @@ -170,8 +177,10 @@ private float calculateDistance(Float score) { switch (similarityFunction) { case "l2_norm": // the returned value of l2_norm is the opposite of the other functions - // (closest to zero means more accurate) - return (float) (sqrt((1 / score) - 1)); + // (closest to zero means more accurate), so to make it consistent + // with the other functions the reverse is returned applying a "1-" + // to the standard transformation + return (float) (1 - (sqrt((1 / score) - 1))); // cosine and dot_product default: return (2 * score) - 1; diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index 904c78ba5fa..b7a22415673 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -106,14 +106,6 @@ private void prepareMapping(String similarityFunction, ElasticsearchVectorStore } } - private double getThreshold(String similarity) { - // l2_norm works in reverse: accept all threshold = 1, accept none = 0 - if (similarity.equals("l2_norm")) { - return 1.0; - } - return 0.0; - } - @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { @@ -139,15 +131,14 @@ public void addAndSearchTest(String similarityFunction) { vectorStore.add(documents); - double threshold = getThreshold(similarityFunction); - Awaitility.await() - .until(() -> vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(threshold)), - hasSize(1)); + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression") + .withTopK(1) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)), hasSize(1)); - List results = vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(threshold)); + List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression") + .withTopK(1) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -161,9 +152,9 @@ public void addAndSearchTest(String similarityFunction) { vectorStore.delete(documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(threshold)), - hasSize(0)); + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression") + .withTopK(1) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)), hasSize(0)); }); } @@ -185,16 +176,14 @@ public void searchWithFilters(String similarityFunction) { vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); - double threshold = getThreshold(similarityFunction); - Awaitility.await() - .until(() -> vectorStore - .similaritySearch(SearchRequest.query("The World").withTopK(5).withSimilarityThreshold(threshold)), - hasSize(3)); + .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)), hasSize(3)); List results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(threshold) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) .withFilterExpression("country == 'NL'")); assertThat(results).hasSize(1); @@ -202,7 +191,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(threshold) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) .withFilterExpression("country == 'BG'")); assertThat(results).hasSize(2); @@ -211,7 +200,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(threshold) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) .withFilterExpression("country == 'BG' && year == 2020")); assertThat(results).hasSize(1); @@ -219,7 +208,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(threshold) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) .withFilterExpression("country in ['BG']")); assertThat(results).hasSize(2); @@ -228,14 +217,14 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(threshold) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) .withFilterExpression("country in ['BG','NL']")); assertThat(results).hasSize(3); results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(threshold) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) .withFilterExpression("country not in ['BG']")); assertThat(results).hasSize(1); @@ -243,7 +232,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(threshold) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) .withFilterExpression("NOT(country not in ['BG'])")); assertThat(results).hasSize(2); @@ -252,7 +241,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(threshold) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) .withFilterExpression( "activationDate > " + ZonedDateTime.parse("1970-01-01T00:00:02Z").toInstant().toEpochMilli())); @@ -280,15 +269,14 @@ public void documentUpdateTest(String similarityFunction) { Map.of("meta1", "meta1")); vectorStore.add(List.of(document)); - double threshold = getThreshold(similarityFunction); - Awaitility.await() - .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(threshold).withTopK(5)), - hasSize(1)); + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Spring") + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withTopK(5)), hasSize(1)); - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(threshold).withTopK(5)); + List results = vectorStore.similaritySearch(SearchRequest.query("Spring") + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withTopK(5)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -303,7 +291,7 @@ public void documentUpdateTest(String similarityFunction) { vectorStore.add(List.of(sameIdDocument)); SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar") .withTopK(5) - .withSimilarityThreshold(threshold); + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL); Awaitility.await() .until(() -> vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(), @@ -336,11 +324,9 @@ public void searchThresholdTest(String similarityFunction) { vectorStore.add(documents); - double threshold = getThreshold(similarityFunction); - SearchRequest query = SearchRequest.query("Great Depression") .withTopK(50) - .withSimilarityThreshold(threshold); + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL); Awaitility.await().until(() -> vectorStore.similaritySearch(query), hasSize(3)); @@ -366,9 +352,9 @@ public void searchThresholdTest(String similarityFunction) { vectorStore.delete(documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(threshold)), - hasSize(0)); + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression") + .withTopK(50) + .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)), hasSize(0)); }); } From 3e31e005ffad30633b2c289ed49bb11d8a2d7aff Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Tue, 16 Apr 2024 15:02:47 +0200 Subject: [PATCH 05/10] refactor unit test --- .../ElasticsearchVectorStoreIT.java | 63 +++++++++---------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index b7a22415673..9b511b77783 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -132,13 +132,12 @@ public void addAndSearchTest(String similarityFunction) { vectorStore.add(documents); Awaitility.await() - .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression") - .withTopK(1) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)), hasSize(1)); + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll()), + hasSize(1)); - List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression") - .withTopK(1) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)); + List results = vectorStore + .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll()); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -152,9 +151,9 @@ public void addAndSearchTest(String similarityFunction) { vectorStore.delete(documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression") - .withTopK(1) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)), hasSize(0)); + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll()), + hasSize(0)); }); } @@ -177,13 +176,13 @@ public void searchWithFilters(String similarityFunction) { vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); Awaitility.await() - .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)), hasSize(3)); + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("The World").withTopK(5).withSimilarityThresholdAll()), + hasSize(3)); List results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withSimilarityThresholdAll() .withFilterExpression("country == 'NL'")); assertThat(results).hasSize(1); @@ -191,7 +190,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withSimilarityThresholdAll() .withFilterExpression("country == 'BG'")); assertThat(results).hasSize(2); @@ -200,7 +199,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withSimilarityThresholdAll() .withFilterExpression("country == 'BG' && year == 2020")); assertThat(results).hasSize(1); @@ -208,7 +207,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withSimilarityThresholdAll() .withFilterExpression("country in ['BG']")); assertThat(results).hasSize(2); @@ -217,14 +216,14 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withSimilarityThresholdAll() .withFilterExpression("country in ['BG','NL']")); assertThat(results).hasSize(3); results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withSimilarityThresholdAll() .withFilterExpression("country not in ['BG']")); assertThat(results).hasSize(1); @@ -232,7 +231,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withSimilarityThresholdAll() .withFilterExpression("NOT(country not in ['BG'])")); assertThat(results).hasSize(2); @@ -241,7 +240,7 @@ public void searchWithFilters(String similarityFunction) { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) + .withSimilarityThresholdAll() .withFilterExpression( "activationDate > " + ZonedDateTime.parse("1970-01-01T00:00:02Z").toInstant().toEpochMilli())); @@ -270,13 +269,12 @@ public void documentUpdateTest(String similarityFunction) { vectorStore.add(List.of(document)); Awaitility.await() - .until(() -> vectorStore.similaritySearch(SearchRequest.query("Spring") - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) - .withTopK(5)), hasSize(1)); + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("Spring").withSimilarityThresholdAll().withTopK(5)), + hasSize(1)); - List results = vectorStore.similaritySearch(SearchRequest.query("Spring") - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) - .withTopK(5)); + List results = vectorStore + .similaritySearch(SearchRequest.query("Spring").withSimilarityThresholdAll().withTopK(5)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -289,9 +287,7 @@ public void documentUpdateTest(String similarityFunction) { "The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2")); vectorStore.add(List.of(sameIdDocument)); - SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar") - .withTopK(5) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL); + SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5).withSimilarityThresholdAll(); Awaitility.await() .until(() -> vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(), @@ -324,9 +320,7 @@ public void searchThresholdTest(String similarityFunction) { vectorStore.add(documents); - SearchRequest query = SearchRequest.query("Great Depression") - .withTopK(50) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL); + SearchRequest query = SearchRequest.query("Great Depression").withTopK(50).withSimilarityThresholdAll(); Awaitility.await().until(() -> vectorStore.similaritySearch(query), hasSize(3)); @@ -352,9 +346,8 @@ public void searchThresholdTest(String similarityFunction) { vectorStore.delete(documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression") - .withTopK(50) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL)), hasSize(0)); + .until(() -> vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThresholdAll()), hasSize(0)); }); } From af5d8c1e2366f523cd12db67fa1641c41b0f5366 Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Fri, 3 May 2024 16:14:56 +0200 Subject: [PATCH 06/10] rebase --- ...ticsearchVectorStoreAutoConfiguration.java | 7 +-- .../ElasticsearchVectorStoreProperties.java | 20 ++----- .../spring-ai-elasticsearch-store/pom.xml | 1 + .../vectorstore/ElasticsearchVectorStore.java | 48 +++++++-------- .../ElasticsearchVectorStoreOptions.java | 19 +----- .../ai/vectorstore/SimilarityFunction.java | 12 ++++ .../ElasticsearchVectorStoreIT.java | 60 +++++++++---------- 7 files changed, 70 insertions(+), 97 deletions(-) create mode 100644 vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java index 78a8fe0fad5..12bb5c2293b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java @@ -28,6 +28,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; +import java.util.Objects; + /** * @author Eddú Meléndez * @author Wei Jiang @@ -52,10 +54,7 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti if (properties.getDimensions() != null) { elasticsearchVectorStoreOptions.setDimensions(properties.getDimensions()); } - if (properties.isDenseVectorIndexing() != null) { - elasticsearchVectorStoreOptions.setDenseVectorIndexing(properties.isDenseVectorIndexing()); - } - if (StringUtils.hasText(properties.getSimilarity())) { + if (properties.getSimilarity() != null) { elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity()); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java index 67f677f6251..cf456b137d5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java @@ -16,6 +16,7 @@ package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; import org.springframework.ai.autoconfigure.CommonVectorStoreProperties; +import org.springframework.ai.vectorstore.SimilarityFunction; import org.springframework.boot.context.properties.ConfigurationProperties; /** @@ -37,15 +38,10 @@ public class ElasticsearchVectorStoreProperties extends CommonVectorStorePropert */ private Integer dimensions; - /** - * Whether to use dense vector indexing. - */ - private Boolean denseVectorIndexing; - /** * The similarity function to use. */ - private String similarity; + private SimilarityFunction similarity; public String getIndexName() { return this.indexName; @@ -63,19 +59,11 @@ public void setDimensions(Integer dimensions) { this.dimensions = dimensions; } - public Boolean isDenseVectorIndexing() { - return denseVectorIndexing; - } - - public void setDenseVectorIndexing(Boolean denseVectorIndexing) { - this.denseVectorIndexing = denseVectorIndexing; - } - - public String getSimilarity() { + public SimilarityFunction getSimilarity() { return similarity; } - public void setSimilarity(String similarity) { + public void setSimilarity(SimilarityFunction similarity) { this.similarity = similarity; } diff --git a/vector-stores/spring-ai-elasticsearch-store/pom.xml b/vector-stores/spring-ai-elasticsearch-store/pom.xml index 717eca2c6cb..4e5a0ceef97 100644 --- a/vector-stores/spring-ai-elasticsearch-store/pom.xml +++ b/vector-stores/spring-ai-elasticsearch-store/pom.xml @@ -35,6 +35,7 @@ co.elastic.clients elasticsearch-java + 8.13.2 diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 4df12a91e07..93bb36cedd7 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -16,10 +16,12 @@ package org.springframework.ai.vectorstore; import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch._types.mapping.TypeMapping; +import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty; +import co.elastic.clients.elasticsearch._types.mapping.Property; import co.elastic.clients.elasticsearch.core.BulkRequest; import co.elastic.clients.elasticsearch.core.BulkResponse; import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; import co.elastic.clients.elasticsearch.core.search.Hit; import co.elastic.clients.elasticsearch.indices.CreateIndexResponse; import co.elastic.clients.json.jackson.JacksonJsonpMapper; @@ -43,6 +45,7 @@ import java.util.stream.Collectors; import static java.lang.Math.sqrt; +import static org.springframework.ai.vectorstore.SimilarityFunction.l2_norm; /** * @author Jemin Huh @@ -59,8 +62,6 @@ public class ElasticsearchVectorStore implements VectorStore, InitializingBean { private final ElasticsearchVectorStoreOptions options; - private String similarityFunction = SIMILARITY_DEFAULT; - private final FilterExpressionConverter filterExpressionConverter; private String similarityFunction; @@ -85,18 +86,18 @@ public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestCli @Override public void add(List documents) { - BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder(); + BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); for (Document document : documents) { if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().isEmpty()) { logger.debug("Calling EmbeddingModel for document id = " + document.getId()); document.setEmbedding(this.embeddingModel.embed(document)); } - builkRequestBuilder.operations(op -> op + bulkRequestBuilder.operations(op -> op .index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(document))); } - BulkResponse bulkRequest = bulkRequest(builkRequestBuilder.build()); + BulkResponse bulkRequest = bulkRequest(bulkRequestBuilder.build()); if (bulkRequest.errors()) { List bulkResponseItems = bulkRequest.items(); @@ -110,10 +111,10 @@ public void add(List documents) { @Override public Optional delete(List idList) { - BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder(); + BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); for (String id : idList) - builkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id))); - return Optional.of(bulkRequest(builkRequestBuilder.build()).errors()); + bulkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id))); + return Optional.of(bulkRequest(bulkRequestBuilder.build()).errors()); } private BulkResponse bulkRequest(BulkRequest bulkRequest) { @@ -131,7 +132,7 @@ public List similaritySearch(SearchRequest searchRequest) { try { float threshold = (float) searchRequest.getSimilarityThreshold(); // reverting l2_norm distance to its original value - if (similarityFunction.equals("l2_norm")) { + if (options.getSimilarity().equals(l2_norm)) { threshold = 1 - threshold; } final float finalThreshold = threshold; @@ -141,11 +142,11 @@ public List similaritySearch(SearchRequest searchRequest) { .toList(); SearchResponse res = elasticsearchClient.search( - sr -> sr.index(this.index) + sr -> sr.index(options.getIndexName()) .knn(knn -> knn.queryVector(vectors) .similarity(finalThreshold) .k(searchRequest.getTopK()) - .field(EMBEDDING) + .field("embedding") .numCandidates((long) (1.5 * searchRequest.getTopK())) .filter(fl -> fl.queryString( qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))), @@ -174,8 +175,8 @@ private Document toDocument(Hit hit) { // more info on score/distance calculation // https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search private float calculateDistance(Float score) { - switch (similarityFunction) { - case "l2_norm": + switch (options.getSimilarity()) { + case l2_norm: // the returned value of l2_norm is the opposite of the other functions // (closest to zero means more accurate), so to make it consistent // with the other functions the reverse is returned applying a "1-" @@ -187,9 +188,9 @@ private float calculateDistance(Float score) { } } - public boolean existsIndex() { + public boolean indexExists() { try { - return this.elasticsearchClient.indices().exists(ex -> ex.index(this.index)).value(); + return this.elasticsearchClient.indices().exists(ex -> ex.index(options.getIndexName())).value(); } catch (IOException e) { throw new RuntimeException(e); @@ -199,18 +200,9 @@ public boolean existsIndex() { private CreateIndexResponse createIndexMapping() { try { return this.elasticsearchClient.indices() - .create(createIndexBuilder -> createIndexBuilder.index(options.getIndexName()) - .mappings(typeMappingBuilder -> { - typeMappingBuilder.properties("embedding", - new Property.Builder() - .denseVector(new DenseVectorProperty.Builder().dims(options.getDimensions()) - .similarity(options.getSimilarity()) - .index(options.isDenseVectorIndexing()) - .build()) - .build()); - - return typeMappingBuilder; - })); + .create(cr -> cr.index(options.getIndexName()) + .mappings(map -> map.properties("embedding", p -> p.denseVector( + dv -> dv.similarity(options.getSimilarity().toString()).dims(options.getDimensions()))))); } catch (IOException e) { throw new RuntimeException(e); diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java index 6cc7794f50a..c685122461c 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java @@ -34,15 +34,10 @@ public class ElasticsearchVectorStoreOptions { */ private int dimensions = 1536; - /** - * Whether to use dense vector indexing. - */ - private boolean denseVectorIndexing = true; - /** * The similarity function to use. */ - private String similarity = "cosine"; + private SimilarityFunction similarity = SimilarityFunction.cosine; public String getIndexName() { return indexName; @@ -60,19 +55,11 @@ public void setDimensions(int dims) { this.dimensions = dims; } - public boolean isDenseVectorIndexing() { - return denseVectorIndexing; - } - - public void setDenseVectorIndexing(boolean denseVectorIndexing) { - this.denseVectorIndexing = denseVectorIndexing; - } - - public String getSimilarity() { + public SimilarityFunction getSimilarity() { return similarity; } - public void setSimilarity(String similarity) { + public void setSimilarity(SimilarityFunction similarity) { this.similarity = similarity; } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java new file mode 100644 index 00000000000..366e2bb4eb5 --- /dev/null +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java @@ -0,0 +1,12 @@ +package org.springframework.ai.vectorstore; + +/* +https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html +max_inner_product is currently not supported because the distance value is not +normalized and would not comply with the requirement of being between 0 and 1 +*/ +public enum SimilarityFunction { + + l2_norm, dot_product, cosine + +} diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index 9b511b77783..6cfb5e5c4e9 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -26,7 +26,7 @@ import java.util.concurrent.TimeUnit; import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch._types.mapping.TypeMapping; +import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord; import co.elastic.clients.json.jackson.JacksonJsonpMapper; import co.elastic.clients.transport.rest_client.RestClientTransport; import com.fasterxml.jackson.databind.DeserializationFeature; @@ -93,29 +93,15 @@ private ApplicationContextRunner getContextRunner() { return new ApplicationContextRunner().withUserConfiguration(TestApplication.class); } - private void prepareMapping(String similarityFunction, ElasticsearchVectorStore vectorStore) { - if (!similarityFunction.equals("cosine")) { // cosine is the default similarity - // function, no need for custom - // mapping - - // vector dimension 1536 is openAI specific - TypeMapping mapping = TypeMapping.of(tm -> tm.properties("embedding", - p -> p.denseVector(dv -> dv.dims(1536).index(true).similarity(similarityFunction)))); - - vectorStore.createIndexMapping(mapping); - } - } @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.delete(List.of("_all")); - // deleting index so that it can be recreated with new mapping - // containing a different similarity function + // deleting indices and data before following tests ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class); - if (elasticsearchClient.indices().exists(ex -> ex.index("spring-ai-document-index")).value()) { - elasticsearchClient.indices().delete(del -> del.index("spring-ai-document-index")); + List indices = elasticsearchClient.cat().indices().valueBody().stream().map(IndicesRecord::index).toList(); + if(!indices.isEmpty()) { + elasticsearchClient.indices().delete(del -> del.index(indices)); } }); } @@ -125,9 +111,8 @@ void cleanDatabase() { public void addAndSearchTest(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - prepareMapping(similarityFunction, vectorStore); + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_"+similarityFunction, ElasticsearchVectorStore.class); vectorStore.add(documents); @@ -162,9 +147,7 @@ public void addAndSearchTest(String similarityFunction) { public void searchWithFilters(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - - prepareMapping(similarityFunction, vectorStore); + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_"+similarityFunction, ElasticsearchVectorStore.class); var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner", Map.of("country", "BG", "year", 2020, "activationDate", new Date(1000))); @@ -260,9 +243,7 @@ public void searchWithFilters(String similarityFunction) { public void documentUpdateTest(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - - prepareMapping(similarityFunction, vectorStore); + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_"+similarityFunction, ElasticsearchVectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", Map.of("meta1", "meta1")); @@ -314,9 +295,7 @@ public void documentUpdateTest(String similarityFunction) { @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void searchThresholdTest(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - - prepareMapping(similarityFunction, vectorStore); + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_"+similarityFunction, ElasticsearchVectorStore.class); vectorStore.add(documents); @@ -355,11 +334,27 @@ public void searchThresholdTest(String similarityFunction) { @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication { - @Bean - public ElasticsearchVectorStore vectorStore(EmbeddingClient embeddingClient, RestClient restClient) { + @Bean("vectorStore_cosine") + public ElasticsearchVectorStore vectorStoreDefault(EmbeddingClient embeddingClient, RestClient restClient) { return new ElasticsearchVectorStore(restClient, embeddingClient); } + @Bean("vectorStore_l2_norm") + public ElasticsearchVectorStore vectorStoreL2(EmbeddingClient embeddingClient, RestClient restClient) { + ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); + options.setIndexName("index_l2"); + options.setSimilarity(SimilarityFunction.l2_norm); + return new ElasticsearchVectorStore(options,restClient, embeddingClient); + } + + @Bean("vectorStore_dot_product") + public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingClient embeddingClient, RestClient restClient) { + ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); + options.setIndexName("index_dot_product"); + options.setSimilarity(SimilarityFunction.dot_product); + return new ElasticsearchVectorStore(options,restClient, embeddingClient); + } + @Bean public EmbeddingModel embeddingModel() { return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); @@ -377,5 +372,4 @@ ElasticsearchClient elasticsearchClient(RestClient restClient) { } } - } From 5acdfec47ccc9c98107a73b3b0d00c0f3ade4071 Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Fri, 3 May 2024 16:15:45 +0200 Subject: [PATCH 07/10] format --- .../ElasticsearchVectorStoreIT.java | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index 6cfb5e5c4e9..1c11aa95254 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -93,14 +93,13 @@ private ApplicationContextRunner getContextRunner() { return new ApplicationContextRunner().withUserConfiguration(TestApplication.class); } - @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { // deleting indices and data before following tests ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class); List indices = elasticsearchClient.cat().indices().valueBody().stream().map(IndicesRecord::index).toList(); - if(!indices.isEmpty()) { + if (!indices.isEmpty()) { elasticsearchClient.indices().delete(del -> del.index(indices)); } }); @@ -112,7 +111,8 @@ public void addAndSearchTest(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_"+similarityFunction, ElasticsearchVectorStore.class); + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, + ElasticsearchVectorStore.class); vectorStore.add(documents); @@ -147,7 +147,8 @@ public void addAndSearchTest(String similarityFunction) { public void searchWithFilters(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_"+similarityFunction, ElasticsearchVectorStore.class); + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, + ElasticsearchVectorStore.class); var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner", Map.of("country", "BG", "year", 2020, "activationDate", new Date(1000))); @@ -243,7 +244,8 @@ public void searchWithFilters(String similarityFunction) { public void documentUpdateTest(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_"+similarityFunction, ElasticsearchVectorStore.class); + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, + ElasticsearchVectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", Map.of("meta1", "meta1")); @@ -295,7 +297,8 @@ public void documentUpdateTest(String similarityFunction) { @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void searchThresholdTest(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_"+similarityFunction, ElasticsearchVectorStore.class); + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, + ElasticsearchVectorStore.class); vectorStore.add(documents); @@ -344,7 +347,7 @@ public ElasticsearchVectorStore vectorStoreL2(EmbeddingClient embeddingClient, R ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); options.setIndexName("index_l2"); options.setSimilarity(SimilarityFunction.l2_norm); - return new ElasticsearchVectorStore(options,restClient, embeddingClient); + return new ElasticsearchVectorStore(options, restClient, embeddingClient); } @Bean("vectorStore_dot_product") @@ -352,7 +355,7 @@ public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingClient embeddingC ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); options.setIndexName("index_dot_product"); options.setSimilarity(SimilarityFunction.dot_product); - return new ElasticsearchVectorStore(options,restClient, embeddingClient); + return new ElasticsearchVectorStore(options, restClient, embeddingClient); } @Bean @@ -372,4 +375,5 @@ ElasticsearchClient elasticsearchClient(RestClient restClient) { } } + } From db32740e6896efd23e48fab9455e224492d82097 Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Mon, 6 May 2024 17:17:12 +0200 Subject: [PATCH 08/10] dependency version, docs --- pom.xml | 1 + .../pages/api/vectordbs/elasticsearch.adoc | 18 +++++++++++++----- spring-ai-spring-boot-autoconfigure/pom.xml | 7 +++++++ ...sticsearchVectorStoreAutoConfiguration.java | 2 -- .../spring-ai-elasticsearch-store/pom.xml | 2 +- .../vectorstore/ElasticsearchVectorStore.java | 4 +--- 6 files changed, 23 insertions(+), 11 deletions(-) diff --git a/pom.xml b/pom.xml index f9abde46034..4299c030c58 100644 --- a/pom.xml +++ b/pom.xml @@ -157,6 +157,7 @@ 0.1.4 2.20.11 42.7.2 + 8.13.3 2.3.4 0.8.0 2.0.46 diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc index 9fc6a895c09..a551c86d44c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc @@ -134,11 +134,18 @@ Properties starting with the `spring.ai.vectorstore.elasticsearch.*` prefix are |`spring.ai.vectorstore.elasticsearch.index-name` | The name of the index to store the vectors. | spring-ai-document-index |`spring.ai.vectorstore.elasticsearch.dimensions` | The number of dimensions in the vector. | 1536 -|`spring.ai.vectorstore.elasticsearch.dense-vector-indexing` | Whether to use dense vector indexing. | true |`spring.ai.vectorstore.elasticsearch.similarity` | The similarity function to use. | `cosine` |`spring.ai.vectorstore.elasticsearch.initialize-schema`| whether to initialize the required schema | `false` |=== +The following similarity functions are available: + +* cosine +* l2_norm +* dot_product + +More details about each in the https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html#dense-vector-params[Elasticsearch Documentation] on dense vectors. + == Metadata Filtering You can leverage the generic, portable xref:api/vectordbs.adoc#metadata-filters[metadata filters] with Elasticsearch as well. @@ -214,10 +221,11 @@ Read the link:https://www.elastic.co/guide/en/elasticsearch/client/java-api-clie ---- @Bean public RestClient restClient() { - RestClientBuilder builder = RestClient.builder(new HttpHost("", 9200, "http")); - Header[] defaultHeaders = new Header[] { new BasicHeader("Authorization", "Basic ") }; - builder.setDefaultHeaders(defaultHeaders); - return builder.build(); + RestClient.builder(new HttpHost("", 9200, "http")) + .setDefaultHeaders(new Header[]{ + new BasicHeader("Authorization", "Basic ") + }) + .build(); } ---- diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 84bd54a188a..5887b84db21 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -260,6 +260,7 @@ true + org.springframework.ai spring-ai-elasticsearch-store @@ -281,6 +282,12 @@ true + + co.elastic.clients + elasticsearch-java + ${elasticsearch-java.version} + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java index 12bb5c2293b..1522abeae35 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java @@ -28,8 +28,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; -import java.util.Objects; - /** * @author Eddú Meléndez * @author Wei Jiang diff --git a/vector-stores/spring-ai-elasticsearch-store/pom.xml b/vector-stores/spring-ai-elasticsearch-store/pom.xml index 4e5a0ceef97..67ae9c969a8 100644 --- a/vector-stores/spring-ai-elasticsearch-store/pom.xml +++ b/vector-stores/spring-ai-elasticsearch-store/pom.xml @@ -35,7 +35,7 @@ co.elastic.clients elasticsearch-java - 8.13.2 + ${elasticsearch-java.version} diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 93bb36cedd7..8517b91b681 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -16,8 +16,6 @@ package org.springframework.ai.vectorstore; import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty; -import co.elastic.clients.elasticsearch._types.mapping.Property; import co.elastic.clients.elasticsearch.core.BulkRequest; import co.elastic.clients.elasticsearch.core.BulkResponse; import co.elastic.clients.elasticsearch.core.SearchResponse; @@ -145,7 +143,7 @@ public List similaritySearch(SearchRequest searchRequest) { sr -> sr.index(options.getIndexName()) .knn(knn -> knn.queryVector(vectors) .similarity(finalThreshold) - .k(searchRequest.getTopK()) + .k((long) searchRequest.getTopK()) .field("embedding") .numCandidates((long) (1.5 * searchRequest.getTopK())) .filter(fl -> fl.queryString( From 8e08cb083352315e67caea7ffeae62e783aeef74 Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Fri, 31 May 2024 17:19:27 +0200 Subject: [PATCH 09/10] rebase --- .../ai/vectorstore/ElasticsearchVectorStore.java | 8 +++----- .../ai/vectorstore/SimilarityFunction.java | 13 ++++++++----- .../ai/vectorstore/ElasticsearchVectorStoreIT.java | 14 +++++++------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 8517b91b681..e481dbaf2bf 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -48,6 +48,7 @@ /** * @author Jemin Huh * @author Wei Jiang + * @author Laura Trotta * @since 1.0.0 */ public class ElasticsearchVectorStore implements VectorStore, InitializingBean { @@ -62,8 +63,6 @@ public class ElasticsearchVectorStore implements VectorStore, InitializingBean { private final FilterExpressionConverter filterExpressionConverter; - private String similarityFunction; - private final boolean initializeSchema; public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) { @@ -134,7 +133,7 @@ public List similaritySearch(SearchRequest searchRequest) { threshold = 1 - threshold; } final float finalThreshold = threshold; - List vectors = this.embeddingClient.embed(searchRequest.getQuery()) + List vectors = this.embeddingModel.embed(searchRequest.getQuery()) .stream() .map(Double::floatValue) .toList(); @@ -151,7 +150,6 @@ public List similaritySearch(SearchRequest searchRequest) { Document.class); return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList()); - } catch (IOException e) { throw new RuntimeException(e); @@ -219,4 +217,4 @@ public void afterPropertiesSet() { } } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java index 366e2bb4eb5..86fc84c01c0 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java @@ -1,10 +1,13 @@ package org.springframework.ai.vectorstore; -/* -https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html -max_inner_product is currently not supported because the distance value is not -normalized and would not comply with the requirement of being between 0 and 1 -*/ +/** + * https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html + * max_inner_product is currently not supported because the distance value is not + * normalized and would not comply with the requirement of being between 0 and 1 + * + * @author Laura Trotta + * @since 1.0.0 + */ public enum SimilarityFunction { l2_norm, dot_product, cosine diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index 1c11aa95254..0832042e5c1 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -64,7 +64,7 @@ class ElasticsearchVectorStoreIT { @Container private static final ElasticsearchContainer elasticsearchContainer = new ElasticsearchContainer( - "docker.elastic.co/elasticsearch/elasticsearch:8.12.2") + "docker.elastic.co/elasticsearch/elasticsearch:8.13.3") .withEnv("xpack.security.enabled", "false"); private final List documents = List.of( @@ -338,24 +338,24 @@ public void searchThresholdTest(String similarityFunction) { public static class TestApplication { @Bean("vectorStore_cosine") - public ElasticsearchVectorStore vectorStoreDefault(EmbeddingClient embeddingClient, RestClient restClient) { - return new ElasticsearchVectorStore(restClient, embeddingClient); + public ElasticsearchVectorStore vectorStoreDefault(EmbeddingModel embeddingModel, RestClient restClient) { + return new ElasticsearchVectorStore(restClient, embeddingModel, true); } @Bean("vectorStore_l2_norm") - public ElasticsearchVectorStore vectorStoreL2(EmbeddingClient embeddingClient, RestClient restClient) { + public ElasticsearchVectorStore vectorStoreL2(EmbeddingModel embeddingModel, RestClient restClient) { ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); options.setIndexName("index_l2"); options.setSimilarity(SimilarityFunction.l2_norm); - return new ElasticsearchVectorStore(options, restClient, embeddingClient); + return new ElasticsearchVectorStore(options, restClient, embeddingModel,true); } @Bean("vectorStore_dot_product") - public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingClient embeddingClient, RestClient restClient) { + public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingModel, RestClient restClient) { ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); options.setIndexName("index_dot_product"); options.setSimilarity(SimilarityFunction.dot_product); - return new ElasticsearchVectorStore(options, restClient, embeddingClient); + return new ElasticsearchVectorStore(options, restClient, embeddingModel,true); } @Bean From ca9fc88e7981122a1abe12c23b7ef1a9591e6446 Mon Sep 17 00:00:00 2001 From: Laura Trotta Date: Mon, 17 Jun 2024 11:39:35 +0200 Subject: [PATCH 10/10] autoconfigure test fix --- ...csearchVectorStoreAutoConfigurationIT.java | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java index 538ff072a6f..a9197295777 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java @@ -18,13 +18,12 @@ import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.ElasticsearchVectorStore; import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.SimilarityFunction; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.elasticsearch.ElasticsearchRestClientAutoConfiguration; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; @@ -51,8 +50,6 @@ class ElasticsearchVectorStoreAutoConfigurationIT { "docker.elastic.co/elasticsearch/elasticsearch:8.12.2") .withEnv("xpack.security.enabled", "false"); - private static final String DEFAULT = "default cosine similarity"; - private List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -65,21 +62,14 @@ class ElasticsearchVectorStoreAutoConfigurationIT { .withPropertyValues("spring.elasticsearch.uris=" + elasticsearchContainer.getHttpHostAddress(), "spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY")); - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) - public void addAndSearchTest(String similarityFunction) { + // No parametrized test based on similarity function, + // by default the bean will be created using cosine. + @Test + public void addAndSearchTest() { this.contextRunner.run(context -> { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } - vectorStore.add(documents); Awaitility.await() @@ -120,7 +110,7 @@ public void propertiesTest() { "spring.ai.vectorstore.elasticsearch.index-name=example", "spring.ai.vectorstore.elasticsearch.dimensions=1024", "spring.ai.vectorstore.elasticsearch.dense-vector-indexing=true", - "spring.ai.vectorstore.elasticsearch.similarity=dot_product") + "spring.ai.vectorstore.elasticsearch.similarity=cosine") .run(context -> { var properties = context.getBean(ElasticsearchVectorStoreProperties.class); var elasticsearchVectorStore = context.getBean(ElasticsearchVectorStore.class); @@ -128,8 +118,7 @@ public void propertiesTest() { assertThat(properties).isNotNull(); assertThat(properties.getIndexName()).isEqualTo("example"); assertThat(properties.getDimensions()).isEqualTo(1024); - assertThat(properties.isDenseVectorIndexing()).isTrue(); - assertThat(properties.getSimilarity()).isEqualTo("dot_product"); + assertThat(properties.getSimilarity()).isEqualTo(SimilarityFunction.cosine); assertThat(elasticsearchVectorStore).isNotNull(); });