diff --git a/pom.xml b/pom.xml index f9abde46034..4299c030c58 100644 --- a/pom.xml +++ b/pom.xml @@ -157,6 +157,7 @@ 0.1.4 2.20.11 42.7.2 + 8.13.3 2.3.4 0.8.0 2.0.46 diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc index 9fc6a895c09..a551c86d44c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc @@ -134,11 +134,18 @@ Properties starting with the `spring.ai.vectorstore.elasticsearch.*` prefix are |`spring.ai.vectorstore.elasticsearch.index-name` | The name of the index to store the vectors. | spring-ai-document-index |`spring.ai.vectorstore.elasticsearch.dimensions` | The number of dimensions in the vector. | 1536 -|`spring.ai.vectorstore.elasticsearch.dense-vector-indexing` | Whether to use dense vector indexing. | true |`spring.ai.vectorstore.elasticsearch.similarity` | The similarity function to use. | `cosine` |`spring.ai.vectorstore.elasticsearch.initialize-schema`| whether to initialize the required schema | `false` |=== +The following similarity functions are available: + +* cosine +* l2_norm +* dot_product + +More details about each are available in the https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html#dense-vector-params[Elasticsearch Documentation] on dense vectors. + == Metadata Filtering You can leverage the generic, portable xref:api/vectordbs.adoc#metadata-filters[metadata filters] with Elasticsearch as well. 
@@ -214,10 +221,11 @@ Read the link:https://www.elastic.co/guide/en/elasticsearch/client/java-api-clie ---- @Bean public RestClient restClient() { - RestClientBuilder builder = RestClient.builder(new HttpHost("", 9200, "http")); - Header[] defaultHeaders = new Header[] { new BasicHeader("Authorization", "Basic ") }; - builder.setDefaultHeaders(defaultHeaders); - return builder.build(); + return RestClient.builder(new HttpHost("", 9200, "http")) + .setDefaultHeaders(new Header[]{ + new BasicHeader("Authorization", "Basic ") + }) + .build(); } ---- diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 84bd54a188a..5887b84db21 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -260,6 +260,7 @@ true + org.springframework.ai spring-ai-elasticsearch-store @@ -281,6 +282,12 @@ true + + co.elastic.clients + elasticsearch-java + ${elasticsearch-java.version} + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java index 78a8fe0fad5..1522abeae35 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java @@ -52,10 +52,7 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti if (properties.getDimensions() != null) { elasticsearchVectorStoreOptions.setDimensions(properties.getDimensions()); } - if (properties.isDenseVectorIndexing() != null) { - 
elasticsearchVectorStoreOptions.setDenseVectorIndexing(properties.isDenseVectorIndexing()); - } - if (StringUtils.hasText(properties.getSimilarity())) { + if (properties.getSimilarity() != null) { elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity()); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java index 67f677f6251..cf456b137d5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java @@ -16,6 +16,7 @@ package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; import org.springframework.ai.autoconfigure.CommonVectorStoreProperties; +import org.springframework.ai.vectorstore.SimilarityFunction; import org.springframework.boot.context.properties.ConfigurationProperties; /** @@ -37,15 +38,10 @@ public class ElasticsearchVectorStoreProperties extends CommonVectorStorePropert */ private Integer dimensions; - /** - * Whether to use dense vector indexing. - */ - private Boolean denseVectorIndexing; - /** * The similarity function to use. 
*/ - private String similarity; + private SimilarityFunction similarity; public String getIndexName() { return this.indexName; @@ -63,19 +59,11 @@ public void setDimensions(Integer dimensions) { this.dimensions = dimensions; } - public Boolean isDenseVectorIndexing() { - return denseVectorIndexing; - } - - public void setDenseVectorIndexing(Boolean denseVectorIndexing) { - this.denseVectorIndexing = denseVectorIndexing; - } - - public String getSimilarity() { + public SimilarityFunction getSimilarity() { return similarity; } - public void setSimilarity(String similarity) { + public void setSimilarity(SimilarityFunction similarity) { this.similarity = similarity; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java index 538ff072a6f..a9197295777 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java @@ -18,13 +18,12 @@ import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.ElasticsearchVectorStore; import org.springframework.ai.vectorstore.SearchRequest; +import 
org.springframework.ai.vectorstore.SimilarityFunction; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.elasticsearch.ElasticsearchRestClientAutoConfiguration; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; @@ -51,8 +50,6 @@ class ElasticsearchVectorStoreAutoConfigurationIT { "docker.elastic.co/elasticsearch/elasticsearch:8.12.2") .withEnv("xpack.security.enabled", "false"); - private static final String DEFAULT = "default cosine similarity"; - private List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -65,21 +62,14 @@ class ElasticsearchVectorStoreAutoConfigurationIT { .withPropertyValues("spring.elasticsearch.uris=" + elasticsearchContainer.getHttpHostAddress(), "spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY")); - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) - public void addAndSearchTest(String similarityFunction) { + // No parametrized test based on similarity function, + // by default the bean will be created using cosine. 
+ @Test + public void addAndSearchTest() { this.contextRunner.run(context -> { ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } - vectorStore.add(documents); Awaitility.await() @@ -120,7 +110,7 @@ public void propertiesTest() { "spring.ai.vectorstore.elasticsearch.index-name=example", "spring.ai.vectorstore.elasticsearch.dimensions=1024", "spring.ai.vectorstore.elasticsearch.dense-vector-indexing=true", - "spring.ai.vectorstore.elasticsearch.similarity=dot_product") + "spring.ai.vectorstore.elasticsearch.similarity=cosine") .run(context -> { var properties = context.getBean(ElasticsearchVectorStoreProperties.class); var elasticsearchVectorStore = context.getBean(ElasticsearchVectorStore.class); @@ -128,8 +118,7 @@ public void propertiesTest() { assertThat(properties).isNotNull(); assertThat(properties.getIndexName()).isEqualTo("example"); assertThat(properties.getDimensions()).isEqualTo(1024); - assertThat(properties.isDenseVectorIndexing()).isTrue(); - assertThat(properties.getSimilarity()).isEqualTo("dot_product"); + assertThat(properties.getSimilarity()).isEqualTo(SimilarityFunction.cosine); assertThat(elasticsearchVectorStore).isNotNull(); }); diff --git a/vector-stores/spring-ai-elasticsearch-store/pom.xml b/vector-stores/spring-ai-elasticsearch-store/pom.xml index 67ac31ac7db..67ae9c969a8 100644 --- a/vector-stores/spring-ai-elasticsearch-store/pom.xml +++ b/vector-stores/spring-ai-elasticsearch-store/pom.xml @@ -35,6 +35,7 @@ co.elastic.clients elasticsearch-java + ${elasticsearch-java.version} @@ -45,7 +46,6 @@ test - org.springframework.ai spring-ai-test diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java 
index 8f672862cd5..e481dbaf2bf 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -16,17 +16,13 @@ package org.springframework.ai.vectorstore; import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty; -import co.elastic.clients.elasticsearch._types.mapping.Property; -import co.elastic.clients.elasticsearch._types.query_dsl.Query; import co.elastic.clients.elasticsearch.core.BulkRequest; import co.elastic.clients.elasticsearch.core.BulkResponse; +import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; import co.elastic.clients.elasticsearch.core.search.Hit; import co.elastic.clients.elasticsearch.indices.CreateIndexResponse; -import co.elastic.clients.json.JsonData; import co.elastic.clients.json.jackson.JacksonJsonpMapper; -import co.elastic.clients.transport.endpoints.BooleanResponse; import co.elastic.clients.transport.rest_client.RestClientTransport; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; @@ -46,16 +42,17 @@ import java.util.Optional; import java.util.stream.Collectors; +import static java.lang.Math.sqrt; +import static org.springframework.ai.vectorstore.SimilarityFunction.l2_norm; + /** * @author Jemin Huh * @author Wei Jiang + * @author Laura Trotta * @since 1.0.0 */ public class ElasticsearchVectorStore implements VectorStore, InitializingBean { - // divided by 2 to get score in the range [0, 1] - public static final String COSINE_SIMILARITY_FUNCTION = "(cosineSimilarity(params.query_vector, 'embedding') + 1.0) / 2"; - private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class); private 
final EmbeddingModel embeddingModel; @@ -66,8 +63,6 @@ public class ElasticsearchVectorStore implements VectorStore, InitializingBean { private final FilterExpressionConverter filterExpressionConverter; - private String similarityFunction; - private final boolean initializeSchema; public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) { @@ -84,30 +79,22 @@ public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestCli this.embeddingModel = embeddingModel; this.options = options; this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter(); - // the potential functions for vector fields at - // https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-script-score-query.html#vector-functions - this.similarityFunction = COSINE_SIMILARITY_FUNCTION; - } - - public ElasticsearchVectorStore withSimilarityFunction(String similarityFunction) { - this.similarityFunction = similarityFunction; - return this; } @Override public void add(List documents) { - BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder(); + BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); for (Document document : documents) { if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().isEmpty()) { logger.debug("Calling EmbeddingModel for document id = " + document.getId()); document.setEmbedding(this.embeddingModel.embed(document)); } - builkRequestBuilder.operations(op -> op + bulkRequestBuilder.operations(op -> op .index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(document))); } - BulkResponse bulkRequest = bulkRequest(builkRequestBuilder.build()); + BulkResponse bulkRequest = bulkRequest(bulkRequestBuilder.build()); if (bulkRequest.errors()) { List bulkResponseItems = bulkRequest.items(); @@ -121,10 +108,10 @@ public void add(List documents) { @Override public Optional delete(List idList) { - BulkRequest.Builder 
builkRequestBuilder = new BulkRequest.Builder(); + BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); for (String id : idList) - builkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id))); - return Optional.of(bulkRequest(builkRequestBuilder.build()).errors()); + bulkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id))); + return Optional.of(bulkRequest(bulkRequestBuilder.build()).errors()); } private BulkResponse bulkRequest(BulkRequest bulkRequest) { @@ -139,61 +126,67 @@ private BulkResponse bulkRequest(BulkRequest bulkRequest) { @Override public List similaritySearch(SearchRequest searchRequest) { Assert.notNull(searchRequest, "The search request must not be null."); - return similaritySearch(this.embeddingModel.embed(searchRequest.getQuery()), searchRequest.getTopK(), - Double.valueOf(searchRequest.getSimilarityThreshold()).floatValue(), - searchRequest.getFilterExpression()); - } - - public List similaritySearch(List embedding, int topK, double similarityThreshold, - Filter.Expression filterExpression) { - return similaritySearch( - new co.elastic.clients.elasticsearch.core.SearchRequest.Builder().index(options.getIndexName()) - .query(getElasticsearchSimilarityQuery(embedding, filterExpression)) - .size(topK) - .minScore(similarityThreshold) - .build()); - } - - private Query getElasticsearchSimilarityQuery(List embedding, Filter.Expression filterExpression) { - return Query.of(queryBuilder -> queryBuilder.scriptScore(scriptScoreQueryBuilder -> scriptScoreQueryBuilder - .query(queryBuilder2 -> queryBuilder2.queryString(queryStringQuerybuilder -> queryStringQuerybuilder - .query(getElasticsearchQueryString(filterExpression)))) - .script(scriptBuilder -> scriptBuilder - .inline(inlineScriptBuilder -> inlineScriptBuilder.source(this.similarityFunction) - .params("query_vector", JsonData.of(embedding)))))); - } - - private String 
getElasticsearchQueryString(Filter.Expression filterExpression) { - return Objects.isNull(filterExpression) ? "*" - : this.filterExpressionConverter.convertExpression(filterExpression); - - } - - private List similaritySearch(co.elastic.clients.elasticsearch.core.SearchRequest searchRequest) { try { - return this.elasticsearchClient.search(searchRequest, Document.class) - .hits() - .hits() + float threshold = (float) searchRequest.getSimilarityThreshold(); + // reverting l2_norm distance to its original value + if (options.getSimilarity().equals(l2_norm)) { + threshold = 1 - threshold; + } + final float finalThreshold = threshold; + List vectors = this.embeddingModel.embed(searchRequest.getQuery()) .stream() - .map(this::toDocument) - .collect(Collectors.toList()); + .map(Double::floatValue) + .toList(); + + SearchResponse res = elasticsearchClient.search( + sr -> sr.index(options.getIndexName()) + .knn(knn -> knn.queryVector(vectors) + .similarity(finalThreshold) + .k((long) searchRequest.getTopK()) + .field("embedding") + .numCandidates((long) (1.5 * searchRequest.getTopK())) + .filter(fl -> fl.queryString( + qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))), + Document.class); + + return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList()); } catch (IOException e) { throw new RuntimeException(e); } } + private String getElasticsearchQueryString(Filter.Expression filterExpression) { + return Objects.isNull(filterExpression) ? 
"*" + : this.filterExpressionConverter.convertExpression(filterExpression); + + } + private Document toDocument(Hit hit) { Document document = hit.source(); - document.getMetadata().put("distance", 1 - hit.score().floatValue()); + document.getMetadata().put("distance", calculateDistance(hit.score().floatValue())); return document; } - private boolean indexExists() { + // more info on score/distance calculation + // https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search + private float calculateDistance(Float score) { + switch (options.getSimilarity()) { + case l2_norm: + // the returned value of l2_norm is the opposite of the other functions + // (closest to zero means more accurate), so to make it consistent + // with the other functions the reverse is returned applying a "1-" + // to the standard transformation + return (float) (1 - (sqrt((1 / score) - 1))); + // cosine and dot_product + default: + return (2 * score) - 1; + } + } + + public boolean indexExists() { try { - BooleanResponse response = this.elasticsearchClient.indices() - .exists(existRequestBuilder -> existRequestBuilder.index(options.getIndexName())); - return response.value(); + return this.elasticsearchClient.indices().exists(ex -> ex.index(options.getIndexName())).value(); } catch (IOException e) { throw new RuntimeException(e); @@ -203,18 +196,9 @@ private boolean indexExists() { private CreateIndexResponse createIndexMapping() { try { return this.elasticsearchClient.indices() - .create(createIndexBuilder -> createIndexBuilder.index(options.getIndexName()) - .mappings(typeMappingBuilder -> { - typeMappingBuilder.properties("embedding", - new Property.Builder() - .denseVector(new DenseVectorProperty.Builder().dims(options.getDimensions()) - .similarity(options.getSimilarity()) - .index(options.isDenseVectorIndexing()) - .build()) - .build()); - - return typeMappingBuilder; - })); + .create(cr -> cr.index(options.getIndexName()) + .mappings(map -> 
map.properties("embedding", p -> p.denseVector( + dv -> dv.similarity(options.getSimilarity().toString()).dims(options.getDimensions()))))); } catch (IOException e) { throw new RuntimeException(e); @@ -233,4 +217,4 @@ public void afterPropertiesSet() { } } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java index 6cc7794f50a..c685122461c 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java @@ -34,15 +34,10 @@ public class ElasticsearchVectorStoreOptions { */ private int dimensions = 1536; - /** - * Whether to use dense vector indexing. - */ - private boolean denseVectorIndexing = true; - /** * The similarity function to use. 
*/ - private String similarity = "cosine"; + private SimilarityFunction similarity = SimilarityFunction.cosine; public String getIndexName() { return indexName; @@ -60,19 +55,11 @@ public void setDimensions(int dims) { this.dimensions = dims; } - public boolean isDenseVectorIndexing() { - return denseVectorIndexing; - } - - public void setDenseVectorIndexing(boolean denseVectorIndexing) { - this.denseVectorIndexing = denseVectorIndexing; - } - - public String getSimilarity() { + public SimilarityFunction getSimilarity() { return similarity; } - public void setSimilarity(String similarity) { + public void setSimilarity(SimilarityFunction similarity) { this.similarity = similarity; } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java new file mode 100644 index 00000000000..86fc84c01c0 --- /dev/null +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java @@ -0,0 +1,15 @@ +package org.springframework.ai.vectorstore; + +/** + * https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html + * max_inner_product is currently not supported because the distance value is not + * normalized and would not comply with the requirement of being between 0 and 1 + * + * @author Laura Trotta + * @since 1.0.0 + */ +public enum SimilarityFunction { + + l2_norm, dot_product, cosine + +} diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index 350c121c431..0832042e5c1 100644 --- 
a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -25,6 +25,12 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord; +import co.elastic.clients.json.jackson.JacksonJsonpMapper; +import co.elastic.clients.transport.rest_client.RestClientTransport; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.http.HttpHost; import org.awaitility.Awaitility; import org.elasticsearch.client.RestClient; @@ -36,7 +42,6 @@ import org.testcontainers.elasticsearch.ElasticsearchContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; @@ -59,14 +64,10 @@ class ElasticsearchVectorStoreIT { @Container private static final ElasticsearchContainer elasticsearchContainer = new ElasticsearchContainer( - "docker.elastic.co/elasticsearch/elasticsearch:8.12.2") + "docker.elastic.co/elasticsearch/elasticsearch:8.13.3") .withEnv("xpack.security.enabled", "false"); - private static final String DEFAULT = "default cosine similarity"; - - protected final ObjectMapper objectMapper = new ObjectMapper(); - - private List documents = List.of( + private final List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", 
"meta2"))); @@ -95,35 +96,33 @@ private ApplicationContextRunner getContextRunner() { @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); - vectorStore.delete(List.of("_all")); + // deleting indices and data before following tests + ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class); + List indices = elasticsearchClient.cat().indices().valueBody().stream().map(IndicesRecord::index).toList(); + if (!indices.isEmpty()) { + elasticsearchClient.indices().delete(del -> del.index(indices)); + } }); } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void addAndSearchTest(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, + ElasticsearchVectorStore.class); vectorStore.add(documents); Awaitility.await() .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), + .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll()), hasSize(1)); List results = vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)); + .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll()); assertThat(results).hasSize(1); Document resultDoc = results.get(0); 
@@ -138,25 +137,18 @@ public void addAndSearchTest(String similarityFunction) { Awaitility.await() .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), + .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll()), hasSize(0)); }); } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void searchWithFilters(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, + ElasticsearchVectorStore.class); var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner", Map.of("country", "BG", "year", 2020, "activationDate", new Date(1000))); @@ -168,7 +160,9 @@ public void searchWithFilters(String similarityFunction) { vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); Awaitility.await() - .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)), hasSize(3)); + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("The World").withTopK(5).withSimilarityThresholdAll()), + hasSize(3)); List results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) @@ -246,18 +240,12 @@ public void searchWithFilters(String similarityFunction) { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = 
dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void documentUpdateTest(String similarityFunction) { getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, + ElasticsearchVectorStore.class); Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", Map.of("meta1", "meta1")); @@ -265,11 +253,11 @@ public void documentUpdateTest(String similarityFunction) { Awaitility.await() .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(0).withTopK(5)), + .similaritySearch(SearchRequest.query("Spring").withSimilarityThresholdAll().withTopK(5)), hasSize(1)); List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(0).withTopK(5)); + .similaritySearch(SearchRequest.query("Spring").withSimilarityThresholdAll().withTopK(5)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -282,7 +270,7 @@ public void documentUpdateTest(String similarityFunction) { "The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2")); vectorStore.add(List.of(sameIdDocument)); - SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5); + SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5).withSimilarityThresholdAll(); Awaitility.await() .until(() -> vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(), @@ -306,24 +294,15 @@ public void documentUpdateTest(String similarityFunction) { } 
@ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { DEFAULT, """ - double value = dotProduct(params.query_vector, 'embedding'); - return sigmoid(1, Math.E, -value); - """, "1 / (1 + l1norm(params.query_vector, 'embedding'))", - "1 / (1 + l2norm(params.query_vector, 'embedding'))" }) + @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) public void searchThresholdTest(String similarityFunction) { - getContextRunner().run(context -> { - ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class); - if (!DEFAULT.equals(similarityFunction)) { - vectorStore.withSimilarityFunction(similarityFunction); - } + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, + ElasticsearchVectorStore.class); vectorStore.add(documents); - SearchRequest query = SearchRequest.query("Great Depression") - .withTopK(50) - .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL); + SearchRequest query = SearchRequest.query("Great Depression").withTopK(50).withSimilarityThresholdAll(); Awaitility.await().until(() -> vectorStore.similaritySearch(query), hasSize(3)); @@ -333,10 +312,10 @@ public void searchThresholdTest(String similarityFunction) { assertThat(distances).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + float thresholdResult = (distances.get(0) + distances.get(1)) / 2; List results = vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(1 - threshold)); + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(thresholdResult)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -349,9 +328,8 @@ public void searchThresholdTest(String similarityFunction) { vectorStore.delete(documents.stream().map(Document::getId).toList()); Awaitility.await() - .until(() -> vectorStore - .similaritySearch(SearchRequest.query("Great 
Depression").withTopK(50).withSimilarityThreshold(0)), - hasSize(0)); + .until(() -> vectorStore.similaritySearch( + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThresholdAll()), hasSize(0)); }); } @@ -359,11 +337,25 @@ public void searchThresholdTest(String similarityFunction) { @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication { - @Bean - public ElasticsearchVectorStore vectorStore(EmbeddingModel embeddingModel) { - return new ElasticsearchVectorStore( - RestClient.builder(HttpHost.create(elasticsearchContainer.getHttpHostAddress())).build(), - embeddingModel, true); + @Bean("vectorStore_cosine") + public ElasticsearchVectorStore vectorStoreDefault(EmbeddingModel embeddingModel, RestClient restClient) { + return new ElasticsearchVectorStore(restClient, embeddingModel, true); + } + + @Bean("vectorStore_l2_norm") + public ElasticsearchVectorStore vectorStoreL2(EmbeddingModel embeddingModel, RestClient restClient) { + ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); + options.setIndexName("index_l2"); + options.setSimilarity(SimilarityFunction.l2_norm); + return new ElasticsearchVectorStore(options, restClient, embeddingModel,true); + } + + @Bean("vectorStore_dot_product") + public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingModel, RestClient restClient) { + ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); + options.setIndexName("index_dot_product"); + options.setSimilarity(SimilarityFunction.dot_product); + return new ElasticsearchVectorStore(options, restClient, embeddingModel,true); } @Bean @@ -371,6 +363,17 @@ public EmbeddingModel embeddingModel() { return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); } + @Bean + RestClient restClient() { + return RestClient.builder(HttpHost.create(elasticsearchContainer.getHttpHostAddress())).build(); + } + + @Bean + 
ElasticsearchClient elasticsearchClient(RestClient restClient) { + return new ElasticsearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper( + new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)))); + } + } }