diff --git a/pom.xml b/pom.xml
index f9abde46034..4299c030c58 100644
--- a/pom.xml
+++ b/pom.xml
@@ -157,6 +157,7 @@
0.1.4
2.20.11
42.7.2
+ 8.13.3
2.3.4
0.8.0
2.0.46
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc
index 9fc6a895c09..a551c86d44c 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/elasticsearch.adoc
@@ -134,11 +134,18 @@ Properties starting with the `spring.ai.vectorstore.elasticsearch.*` prefix are
|`spring.ai.vectorstore.elasticsearch.index-name` | The name of the index to store the vectors. | spring-ai-document-index
|`spring.ai.vectorstore.elasticsearch.dimensions` | The number of dimensions in the vector. | 1536
-|`spring.ai.vectorstore.elasticsearch.dense-vector-indexing` | Whether to use dense vector indexing. | true
|`spring.ai.vectorstore.elasticsearch.similarity` | The similarity function to use. | `cosine`
|`spring.ai.vectorstore.elasticsearch.initialize-schema`| whether to initialize the required schema | `false`
|===
+The following similarity functions are available:
+
+* cosine
+* l2_norm
+* dot_product
+
+More details about each in the https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html#dense-vector-params[Elasticsearch Documentation] on dense vectors.
+
== Metadata Filtering
You can leverage the generic, portable xref:api/vectordbs.adoc#metadata-filters[metadata filters] with Elasticsearch as well.
@@ -214,10 +221,11 @@ Read the link:https://www.elastic.co/guide/en/elasticsearch/client/java-api-clie
----
@Bean
public RestClient restClient() {
- RestClientBuilder builder = RestClient.builder(new HttpHost("", 9200, "http"));
- Header[] defaultHeaders = new Header[] { new BasicHeader("Authorization", "Basic ") };
- builder.setDefaultHeaders(defaultHeaders);
- return builder.build();
+ RestClient.builder(new HttpHost("", 9200, "http"))
+ .setDefaultHeaders(new Header[]{
+ new BasicHeader("Authorization", "Basic ")
+ })
+ .build();
}
----
diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml
index 84bd54a188a..5887b84db21 100644
--- a/spring-ai-spring-boot-autoconfigure/pom.xml
+++ b/spring-ai-spring-boot-autoconfigure/pom.xml
@@ -260,6 +260,7 @@
true
+
org.springframework.ai
spring-ai-elasticsearch-store
@@ -281,6 +282,12 @@
true
+
+ co.elastic.clients
+ elasticsearch-java
+ ${elasticsearch-java.version}
+
+
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java
index 78a8fe0fad5..1522abeae35 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java
@@ -52,10 +52,7 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti
if (properties.getDimensions() != null) {
elasticsearchVectorStoreOptions.setDimensions(properties.getDimensions());
}
- if (properties.isDenseVectorIndexing() != null) {
- elasticsearchVectorStoreOptions.setDenseVectorIndexing(properties.isDenseVectorIndexing());
- }
- if (StringUtils.hasText(properties.getSimilarity())) {
+ if (properties.getSimilarity() != null) {
elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity());
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java
index 67f677f6251..cf456b137d5 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreProperties.java
@@ -16,6 +16,7 @@
package org.springframework.ai.autoconfigure.vectorstore.elasticsearch;
import org.springframework.ai.autoconfigure.CommonVectorStoreProperties;
+import org.springframework.ai.vectorstore.SimilarityFunction;
import org.springframework.boot.context.properties.ConfigurationProperties;
/**
@@ -37,15 +38,10 @@ public class ElasticsearchVectorStoreProperties extends CommonVectorStorePropert
*/
private Integer dimensions;
- /**
- * Whether to use dense vector indexing.
- */
- private Boolean denseVectorIndexing;
-
/**
* The similarity function to use.
*/
- private String similarity;
+ private SimilarityFunction similarity;
public String getIndexName() {
return this.indexName;
@@ -63,19 +59,11 @@ public void setDimensions(Integer dimensions) {
this.dimensions = dimensions;
}
- public Boolean isDenseVectorIndexing() {
- return denseVectorIndexing;
- }
-
- public void setDenseVectorIndexing(Boolean denseVectorIndexing) {
- this.denseVectorIndexing = denseVectorIndexing;
- }
-
- public String getSimilarity() {
+ public SimilarityFunction getSimilarity() {
return similarity;
}
- public void setSimilarity(String similarity) {
+ public void setSimilarity(SimilarityFunction similarity) {
this.similarity = similarity;
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java
index 538ff072a6f..a9197295777 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfigurationIT.java
@@ -18,13 +18,12 @@
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.SimilarityFunction;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.elasticsearch.ElasticsearchRestClientAutoConfiguration;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
@@ -51,8 +50,6 @@ class ElasticsearchVectorStoreAutoConfigurationIT {
"docker.elastic.co/elasticsearch/elasticsearch:8.12.2")
.withEnv("xpack.security.enabled", "false");
- private static final String DEFAULT = "default cosine similarity";
-
private List documents = List.of(
new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
@@ -65,21 +62,14 @@ class ElasticsearchVectorStoreAutoConfigurationIT {
.withPropertyValues("spring.elasticsearch.uris=" + elasticsearchContainer.getHttpHostAddress(),
"spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY"));
- @ParameterizedTest(name = "{0} : {displayName} ")
- @ValueSource(strings = { DEFAULT, """
- double value = dotProduct(params.query_vector, 'embedding');
- return sigmoid(1, Math.E, -value);
- """, "1 / (1 + l1norm(params.query_vector, 'embedding'))",
- "1 / (1 + l2norm(params.query_vector, 'embedding'))" })
- public void addAndSearchTest(String similarityFunction) {
+ // No parametrized test based on similarity function,
+ // by default the bean will be created using cosine.
+ @Test
+ public void addAndSearchTest() {
this.contextRunner.run(context -> {
ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class);
- if (!DEFAULT.equals(similarityFunction)) {
- vectorStore.withSimilarityFunction(similarityFunction);
- }
-
vectorStore.add(documents);
Awaitility.await()
@@ -120,7 +110,7 @@ public void propertiesTest() {
"spring.ai.vectorstore.elasticsearch.index-name=example",
"spring.ai.vectorstore.elasticsearch.dimensions=1024",
"spring.ai.vectorstore.elasticsearch.dense-vector-indexing=true",
- "spring.ai.vectorstore.elasticsearch.similarity=dot_product")
+ "spring.ai.vectorstore.elasticsearch.similarity=cosine")
.run(context -> {
var properties = context.getBean(ElasticsearchVectorStoreProperties.class);
var elasticsearchVectorStore = context.getBean(ElasticsearchVectorStore.class);
@@ -128,8 +118,7 @@ public void propertiesTest() {
assertThat(properties).isNotNull();
assertThat(properties.getIndexName()).isEqualTo("example");
assertThat(properties.getDimensions()).isEqualTo(1024);
- assertThat(properties.isDenseVectorIndexing()).isTrue();
- assertThat(properties.getSimilarity()).isEqualTo("dot_product");
+ assertThat(properties.getSimilarity()).isEqualTo(SimilarityFunction.cosine);
assertThat(elasticsearchVectorStore).isNotNull();
});
diff --git a/vector-stores/spring-ai-elasticsearch-store/pom.xml b/vector-stores/spring-ai-elasticsearch-store/pom.xml
index 67ac31ac7db..67ae9c969a8 100644
--- a/vector-stores/spring-ai-elasticsearch-store/pom.xml
+++ b/vector-stores/spring-ai-elasticsearch-store/pom.xml
@@ -35,6 +35,7 @@
co.elastic.clients
elasticsearch-java
+ ${elasticsearch-java.version}
@@ -45,7 +46,6 @@
test
-
org.springframework.ai
spring-ai-test
diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java
index 8f672862cd5..e481dbaf2bf 100644
--- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java
+++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java
@@ -16,17 +16,13 @@
package org.springframework.ai.vectorstore;
import co.elastic.clients.elasticsearch.ElasticsearchClient;
-import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty;
-import co.elastic.clients.elasticsearch._types.mapping.Property;
-import co.elastic.clients.elasticsearch._types.query_dsl.Query;
import co.elastic.clients.elasticsearch.core.BulkRequest;
import co.elastic.clients.elasticsearch.core.BulkResponse;
+import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
import co.elastic.clients.elasticsearch.core.search.Hit;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
-import co.elastic.clients.json.JsonData;
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
-import co.elastic.clients.transport.endpoints.BooleanResponse;
import co.elastic.clients.transport.rest_client.RestClientTransport;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -46,16 +42,17 @@
import java.util.Optional;
import java.util.stream.Collectors;
+import static java.lang.Math.sqrt;
+import static org.springframework.ai.vectorstore.SimilarityFunction.l2_norm;
+
/**
* @author Jemin Huh
* @author Wei Jiang
+ * @author Laura Trotta
* @since 1.0.0
*/
public class ElasticsearchVectorStore implements VectorStore, InitializingBean {
- // divided by 2 to get score in the range [0, 1]
- public static final String COSINE_SIMILARITY_FUNCTION = "(cosineSimilarity(params.query_vector, 'embedding') + 1.0) / 2";
-
private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class);
private final EmbeddingModel embeddingModel;
@@ -66,8 +63,6 @@ public class ElasticsearchVectorStore implements VectorStore, InitializingBean {
private final FilterExpressionConverter filterExpressionConverter;
- private String similarityFunction;
-
private final boolean initializeSchema;
public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) {
@@ -84,30 +79,22 @@ public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestCli
this.embeddingModel = embeddingModel;
this.options = options;
this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();
- // the potential functions for vector fields at
- // https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-script-score-query.html#vector-functions
- this.similarityFunction = COSINE_SIMILARITY_FUNCTION;
- }
-
- public ElasticsearchVectorStore withSimilarityFunction(String similarityFunction) {
- this.similarityFunction = similarityFunction;
- return this;
}
@Override
public void add(List documents) {
- BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder();
+ BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();
for (Document document : documents) {
if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().isEmpty()) {
logger.debug("Calling EmbeddingModel for document id = " + document.getId());
document.setEmbedding(this.embeddingModel.embed(document));
}
- builkRequestBuilder.operations(op -> op
+ bulkRequestBuilder.operations(op -> op
.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(document)));
}
- BulkResponse bulkRequest = bulkRequest(builkRequestBuilder.build());
+ BulkResponse bulkRequest = bulkRequest(bulkRequestBuilder.build());
if (bulkRequest.errors()) {
List bulkResponseItems = bulkRequest.items();
@@ -121,10 +108,10 @@ public void add(List documents) {
@Override
public Optional delete(List idList) {
- BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder();
+ BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();
for (String id : idList)
- builkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id)));
- return Optional.of(bulkRequest(builkRequestBuilder.build()).errors());
+ bulkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id)));
+ return Optional.of(bulkRequest(bulkRequestBuilder.build()).errors());
}
private BulkResponse bulkRequest(BulkRequest bulkRequest) {
@@ -139,61 +126,67 @@ private BulkResponse bulkRequest(BulkRequest bulkRequest) {
@Override
public List similaritySearch(SearchRequest searchRequest) {
Assert.notNull(searchRequest, "The search request must not be null.");
- return similaritySearch(this.embeddingModel.embed(searchRequest.getQuery()), searchRequest.getTopK(),
- Double.valueOf(searchRequest.getSimilarityThreshold()).floatValue(),
- searchRequest.getFilterExpression());
- }
-
- public List similaritySearch(List embedding, int topK, double similarityThreshold,
- Filter.Expression filterExpression) {
- return similaritySearch(
- new co.elastic.clients.elasticsearch.core.SearchRequest.Builder().index(options.getIndexName())
- .query(getElasticsearchSimilarityQuery(embedding, filterExpression))
- .size(topK)
- .minScore(similarityThreshold)
- .build());
- }
-
- private Query getElasticsearchSimilarityQuery(List embedding, Filter.Expression filterExpression) {
- return Query.of(queryBuilder -> queryBuilder.scriptScore(scriptScoreQueryBuilder -> scriptScoreQueryBuilder
- .query(queryBuilder2 -> queryBuilder2.queryString(queryStringQuerybuilder -> queryStringQuerybuilder
- .query(getElasticsearchQueryString(filterExpression))))
- .script(scriptBuilder -> scriptBuilder
- .inline(inlineScriptBuilder -> inlineScriptBuilder.source(this.similarityFunction)
- .params("query_vector", JsonData.of(embedding))))));
- }
-
- private String getElasticsearchQueryString(Filter.Expression filterExpression) {
- return Objects.isNull(filterExpression) ? "*"
- : this.filterExpressionConverter.convertExpression(filterExpression);
-
- }
-
- private List similaritySearch(co.elastic.clients.elasticsearch.core.SearchRequest searchRequest) {
try {
- return this.elasticsearchClient.search(searchRequest, Document.class)
- .hits()
- .hits()
+ float threshold = (float) searchRequest.getSimilarityThreshold();
+ // reverting l2_norm distance to its original value
+ if (options.getSimilarity().equals(l2_norm)) {
+ threshold = 1 - threshold;
+ }
+ final float finalThreshold = threshold;
+ List vectors = this.embeddingModel.embed(searchRequest.getQuery())
.stream()
- .map(this::toDocument)
- .collect(Collectors.toList());
+ .map(Double::floatValue)
+ .toList();
+
+ SearchResponse res = elasticsearchClient.search(
+ sr -> sr.index(options.getIndexName())
+ .knn(knn -> knn.queryVector(vectors)
+ .similarity(finalThreshold)
+ .k((long) searchRequest.getTopK())
+ .field("embedding")
+ .numCandidates((long) (1.5 * searchRequest.getTopK()))
+ .filter(fl -> fl.queryString(
+ qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))),
+ Document.class);
+
+ return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList());
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
+ private String getElasticsearchQueryString(Filter.Expression filterExpression) {
+ return Objects.isNull(filterExpression) ? "*"
+ : this.filterExpressionConverter.convertExpression(filterExpression);
+
+ }
+
private Document toDocument(Hit hit) {
Document document = hit.source();
- document.getMetadata().put("distance", 1 - hit.score().floatValue());
+ document.getMetadata().put("distance", calculateDistance(hit.score().floatValue()));
return document;
}
- private boolean indexExists() {
+ // more info on score/distance calculation
+ // https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search
+ private float calculateDistance(Float score) {
+ switch (options.getSimilarity()) {
+ case l2_norm:
+ // the returned value of l2_norm is the opposite of the other functions
+ // (closest to zero means more accurate), so to make it consistent
+ // with the other functions the reverse is returned applying a "1-"
+ // to the standard transformation
+ return (float) (1 - (sqrt((1 / score) - 1)));
+ // cosine and dot_product
+ default:
+ return (2 * score) - 1;
+ }
+ }
+
+ public boolean indexExists() {
try {
- BooleanResponse response = this.elasticsearchClient.indices()
- .exists(existRequestBuilder -> existRequestBuilder.index(options.getIndexName()));
- return response.value();
+ return this.elasticsearchClient.indices().exists(ex -> ex.index(options.getIndexName())).value();
}
catch (IOException e) {
throw new RuntimeException(e);
@@ -203,18 +196,9 @@ private boolean indexExists() {
private CreateIndexResponse createIndexMapping() {
try {
return this.elasticsearchClient.indices()
- .create(createIndexBuilder -> createIndexBuilder.index(options.getIndexName())
- .mappings(typeMappingBuilder -> {
- typeMappingBuilder.properties("embedding",
- new Property.Builder()
- .denseVector(new DenseVectorProperty.Builder().dims(options.getDimensions())
- .similarity(options.getSimilarity())
- .index(options.isDenseVectorIndexing())
- .build())
- .build());
-
- return typeMappingBuilder;
- }));
+ .create(cr -> cr.index(options.getIndexName())
+ .mappings(map -> map.properties("embedding", p -> p.denseVector(
+ dv -> dv.similarity(options.getSimilarity().toString()).dims(options.getDimensions())))));
}
catch (IOException e) {
throw new RuntimeException(e);
@@ -233,4 +217,4 @@ public void afterPropertiesSet() {
}
}
-}
\ No newline at end of file
+}
diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java
index 6cc7794f50a..c685122461c 100644
--- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java
+++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreOptions.java
@@ -34,15 +34,10 @@ public class ElasticsearchVectorStoreOptions {
*/
private int dimensions = 1536;
- /**
- * Whether to use dense vector indexing.
- */
- private boolean denseVectorIndexing = true;
-
/**
* The similarity function to use.
*/
- private String similarity = "cosine";
+ private SimilarityFunction similarity = SimilarityFunction.cosine;
public String getIndexName() {
return indexName;
@@ -60,19 +55,11 @@ public void setDimensions(int dims) {
this.dimensions = dims;
}
- public boolean isDenseVectorIndexing() {
- return denseVectorIndexing;
- }
-
- public void setDenseVectorIndexing(boolean denseVectorIndexing) {
- this.denseVectorIndexing = denseVectorIndexing;
- }
-
- public String getSimilarity() {
+ public SimilarityFunction getSimilarity() {
return similarity;
}
- public void setSimilarity(String similarity) {
+ public void setSimilarity(SimilarityFunction similarity) {
this.similarity = similarity;
}
diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java
new file mode 100644
index 00000000000..86fc84c01c0
--- /dev/null
+++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/SimilarityFunction.java
@@ -0,0 +1,15 @@
+package org.springframework.ai.vectorstore;
+
+/**
+ * https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html
+ * max_inner_product is currently not supported because the distance value is not
+ * normalized and would not comply with the requirement of being between 0 and 1
+ *
+ * @author Laura Trotta
+ * @since 1.0.0
+ */
+public enum SimilarityFunction {
+
+ l2_norm, dot_product, cosine
+
+}
diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java
index 350c121c431..0832042e5c1 100644
--- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java
+++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java
@@ -25,6 +25,12 @@
import java.util.UUID;
import java.util.concurrent.TimeUnit;
+import co.elastic.clients.elasticsearch.ElasticsearchClient;
+import co.elastic.clients.elasticsearch.cat.indices.IndicesRecord;
+import co.elastic.clients.json.jackson.JacksonJsonpMapper;
+import co.elastic.clients.transport.rest_client.RestClientTransport;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.http.HttpHost;
import org.awaitility.Awaitility;
import org.elasticsearch.client.RestClient;
@@ -36,7 +42,6 @@
import org.testcontainers.elasticsearch.ElasticsearchContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
-import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
@@ -59,14 +64,10 @@ class ElasticsearchVectorStoreIT {
@Container
private static final ElasticsearchContainer elasticsearchContainer = new ElasticsearchContainer(
- "docker.elastic.co/elasticsearch/elasticsearch:8.12.2")
+ "docker.elastic.co/elasticsearch/elasticsearch:8.13.3")
.withEnv("xpack.security.enabled", "false");
- private static final String DEFAULT = "default cosine similarity";
-
- protected final ObjectMapper objectMapper = new ObjectMapper();
-
- private List documents = List.of(
+ private final List documents = List.of(
new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
@@ -95,35 +96,33 @@ private ApplicationContextRunner getContextRunner() {
@BeforeEach
void cleanDatabase() {
getContextRunner().run(context -> {
- VectorStore vectorStore = context.getBean(VectorStore.class);
- vectorStore.delete(List.of("_all"));
+ // deleting indices and data before following tests
+ ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);
+ List indices = elasticsearchClient.cat().indices().valueBody().stream().map(IndicesRecord::index).toList();
+ if (!indices.isEmpty()) {
+ elasticsearchClient.indices().delete(del -> del.index(indices));
+ }
});
}
@ParameterizedTest(name = "{0} : {displayName} ")
- @ValueSource(strings = { DEFAULT, """
- double value = dotProduct(params.query_vector, 'embedding');
- return sigmoid(1, Math.E, -value);
- """, "1 / (1 + l1norm(params.query_vector, 'embedding'))",
- "1 / (1 + l2norm(params.query_vector, 'embedding'))" })
+ @ValueSource(strings = { "cosine", "l2_norm", "dot_product" })
public void addAndSearchTest(String similarityFunction) {
getContextRunner().run(context -> {
- ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class);
- if (!DEFAULT.equals(similarityFunction)) {
- vectorStore.withSimilarityFunction(similarityFunction);
- }
+ ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction,
+ ElasticsearchVectorStore.class);
vectorStore.add(documents);
Awaitility.await()
.until(() -> vectorStore
- .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
+ .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll()),
hasSize(1));
List results = vectorStore
- .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0));
+ .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll());
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
@@ -138,25 +137,18 @@ public void addAndSearchTest(String similarityFunction) {
Awaitility.await()
.until(() -> vectorStore
- .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
+ .similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThresholdAll()),
hasSize(0));
});
}
@ParameterizedTest(name = "{0} : {displayName} ")
- @ValueSource(strings = { DEFAULT, """
- double value = dotProduct(params.query_vector, 'embedding');
- return sigmoid(1, Math.E, -value);
- """, "1 / (1 + l1norm(params.query_vector, 'embedding'))",
- "1 / (1 + l2norm(params.query_vector, 'embedding'))" })
+ @ValueSource(strings = { "cosine", "l2_norm", "dot_product" })
public void searchWithFilters(String similarityFunction) {
getContextRunner().run(context -> {
- ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class);
-
- if (!DEFAULT.equals(similarityFunction)) {
- vectorStore.withSimilarityFunction(similarityFunction);
- }
+ ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction,
+ ElasticsearchVectorStore.class);
var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner",
Map.of("country", "BG", "year", 2020, "activationDate", new Date(1000)));
@@ -168,7 +160,9 @@ public void searchWithFilters(String similarityFunction) {
vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2));
Awaitility.await()
- .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)), hasSize(3));
+ .until(() -> vectorStore
+ .similaritySearch(SearchRequest.query("The World").withTopK(5).withSimilarityThresholdAll()),
+ hasSize(3));
List results = vectorStore.similaritySearch(SearchRequest.query("The World")
.withTopK(5)
@@ -246,18 +240,12 @@ public void searchWithFilters(String similarityFunction) {
}
@ParameterizedTest(name = "{0} : {displayName} ")
- @ValueSource(strings = { DEFAULT, """
- double value = dotProduct(params.query_vector, 'embedding');
- return sigmoid(1, Math.E, -value);
- """, "1 / (1 + l1norm(params.query_vector, 'embedding'))",
- "1 / (1 + l2norm(params.query_vector, 'embedding'))" })
+ @ValueSource(strings = { "cosine", "l2_norm", "dot_product" })
public void documentUpdateTest(String similarityFunction) {
getContextRunner().run(context -> {
- ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class);
- if (!DEFAULT.equals(similarityFunction)) {
- vectorStore.withSimilarityFunction(similarityFunction);
- }
+ ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction,
+ ElasticsearchVectorStore.class);
Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
Map.of("meta1", "meta1"));
@@ -265,11 +253,11 @@ public void documentUpdateTest(String similarityFunction) {
Awaitility.await()
.until(() -> vectorStore
- .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(0).withTopK(5)),
+ .similaritySearch(SearchRequest.query("Spring").withSimilarityThresholdAll().withTopK(5)),
hasSize(1));
List results = vectorStore
- .similaritySearch(SearchRequest.query("Spring").withSimilarityThreshold(0).withTopK(5));
+ .similaritySearch(SearchRequest.query("Spring").withSimilarityThresholdAll().withTopK(5));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
@@ -282,7 +270,7 @@ public void documentUpdateTest(String similarityFunction) {
"The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2"));
vectorStore.add(List.of(sameIdDocument));
- SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5);
+ SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5).withSimilarityThresholdAll();
Awaitility.await()
.until(() -> vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(),
@@ -306,24 +294,15 @@ public void documentUpdateTest(String similarityFunction) {
}
@ParameterizedTest(name = "{0} : {displayName} ")
- @ValueSource(strings = { DEFAULT, """
- double value = dotProduct(params.query_vector, 'embedding');
- return sigmoid(1, Math.E, -value);
- """, "1 / (1 + l1norm(params.query_vector, 'embedding'))",
- "1 / (1 + l2norm(params.query_vector, 'embedding'))" })
+ @ValueSource(strings = { "cosine", "l2_norm", "dot_product" })
public void searchThresholdTest(String similarityFunction) {
-
getContextRunner().run(context -> {
- ElasticsearchVectorStore vectorStore = context.getBean(ElasticsearchVectorStore.class);
- if (!DEFAULT.equals(similarityFunction)) {
- vectorStore.withSimilarityFunction(similarityFunction);
- }
+ ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction,
+ ElasticsearchVectorStore.class);
vectorStore.add(documents);
- SearchRequest query = SearchRequest.query("Great Depression")
- .withTopK(50)
- .withSimilarityThreshold(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL);
+ SearchRequest query = SearchRequest.query("Great Depression").withTopK(50).withSimilarityThresholdAll();
Awaitility.await().until(() -> vectorStore.similaritySearch(query), hasSize(3));
@@ -333,10 +312,10 @@ public void searchThresholdTest(String similarityFunction) {
assertThat(distances).hasSize(3);
- float threshold = (distances.get(0) + distances.get(1)) / 2;
+ float thresholdResult = (distances.get(0) + distances.get(1)) / 2;
List results = vectorStore.similaritySearch(
- SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(1 - threshold));
+ SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(thresholdResult));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
@@ -349,9 +328,8 @@ public void searchThresholdTest(String similarityFunction) {
vectorStore.delete(documents.stream().map(Document::getId).toList());
Awaitility.await()
- .until(() -> vectorStore
- .similaritySearch(SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(0)),
- hasSize(0));
+ .until(() -> vectorStore.similaritySearch(
+ SearchRequest.query("Great Depression").withTopK(50).withSimilarityThresholdAll()), hasSize(0));
});
}
@@ -359,11 +337,25 @@ public void searchThresholdTest(String similarityFunction) {
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
public static class TestApplication {
- @Bean
- public ElasticsearchVectorStore vectorStore(EmbeddingModel embeddingModel) {
- return new ElasticsearchVectorStore(
- RestClient.builder(HttpHost.create(elasticsearchContainer.getHttpHostAddress())).build(),
- embeddingModel, true);
+ @Bean("vectorStore_cosine")
+ public ElasticsearchVectorStore vectorStoreDefault(EmbeddingModel embeddingModel, RestClient restClient) {
+ return new ElasticsearchVectorStore(restClient, embeddingModel, true);
+ }
+
+ @Bean("vectorStore_l2_norm")
+ public ElasticsearchVectorStore vectorStoreL2(EmbeddingModel embeddingModel, RestClient restClient) {
+ ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
+ options.setIndexName("index_l2");
+ options.setSimilarity(SimilarityFunction.l2_norm);
+ return new ElasticsearchVectorStore(options, restClient, embeddingModel,true);
+ }
+
+ @Bean("vectorStore_dot_product")
+ public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingModel, RestClient restClient) {
+ ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
+ options.setIndexName("index_dot_product");
+ options.setSimilarity(SimilarityFunction.dot_product);
+ return new ElasticsearchVectorStore(options, restClient, embeddingModel,true);
}
@Bean
@@ -371,6 +363,17 @@ public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
}
+ @Bean
+ RestClient restClient() {
+ return RestClient.builder(HttpHost.create(elasticsearchContainer.getHttpHostAddress())).build();
+ }
+
+ @Bean
+ ElasticsearchClient elasticsearchClient(RestClient restClient) {
+ return new ElasticsearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(
+ new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false))));
+ }
+
}
}