Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@
import org.elasticsearch.client.RestClient;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.elasticsearch.ElasticsearchClientAutoConfiguration;
import org.springframework.boot.autoconfigure.elasticsearch.ElasticsearchRestClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.util.StringUtils;

/**
* @author Eddú Meléndez
* @author Wei Jiang
* @since 1.0.0
*/
@AutoConfiguration(after = ElasticsearchRestClientAutoConfiguration.class)
Expand All @@ -40,10 +41,22 @@ class ElasticsearchVectorStoreAutoConfiguration {
@ConditionalOnMissingBean
ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properties, RestClient restClient,
EmbeddingClient embeddingClient) {
ElasticsearchVectorStoreOptions elasticsearchVectorStoreOptions = new ElasticsearchVectorStoreOptions();

if (StringUtils.hasText(properties.getIndexName())) {
return new ElasticsearchVectorStore(properties.getIndexName(), restClient, embeddingClient);
elasticsearchVectorStoreOptions.setIndexName(properties.getIndexName());
}
if (properties.getDims() != null) {
elasticsearchVectorStoreOptions.setDims(properties.getDims());
}
if (properties.isDenseVectorIndexing() != null) {
elasticsearchVectorStoreOptions.setDenseVectorIndexing(properties.isDenseVectorIndexing());
}
return new ElasticsearchVectorStore(restClient, embeddingClient);
if (StringUtils.hasText(properties.getSimilarity())) {
elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity());
}

return new ElasticsearchVectorStore(elasticsearchVectorStoreOptions, restClient, embeddingClient);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,20 @@

/**
* @author Eddú Meléndez
* @author Wei Jiang
* @since 1.0.0
*/
@ConfigurationProperties(prefix = "spring.ai.vectorstore.elasticsearch")
public class ElasticsearchVectorStoreProperties {

private String indexName;

private Integer dims;

private Boolean denseVectorIndexing;

private String similarity;

public String getIndexName() {
return this.indexName;
}
Expand All @@ -34,4 +41,28 @@ public void setIndexName(String indexName) {
this.indexName = indexName;
}

public Integer getDims() {
return dims;
}

public void setDims(Integer dims) {
this.dims = dims;
}

public Boolean isDenseVectorIndexing() {
return denseVectorIndexing;
}

public void setDenseVectorIndexing(Boolean denseVectorIndexing) {
this.denseVectorIndexing = denseVectorIndexing;
}

public String getSimilarity() {
return similarity;
}

public void setSimilarity(String similarity) {
this.similarity = similarity;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.ai.autoconfigure.vectorstore.elasticsearch;

import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
Expand Down Expand Up @@ -107,6 +108,33 @@ public void addAndSearchTest(String similarityFunction) {
});
}

@Test
public void propertiesTest() {

new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(ElasticsearchRestClientAutoConfiguration.class,
ElasticsearchVectorStoreAutoConfiguration.class, RestClientAutoConfiguration.class,
SpringAiRetryAutoConfiguration.class, OpenAiAutoConfiguration.class))
.withPropertyValues("spring.elasticsearch.uris=" + elasticsearchContainer.getHttpHostAddress(),
"spring.ai.openai.api-key=" + System.getenv("OPENAI_API_KEY"),
"spring.ai.vectorstore.elasticsearch.index-name=example",
"spring.ai.vectorstore.elasticsearch.dims=1024",
"spring.ai.vectorstore.elasticsearch.dense-vector-indexing=true",
"spring.ai.vectorstore.elasticsearch.similarity=dot_product")
.run(context -> {
var properties = context.getBean(ElasticsearchVectorStoreProperties.class);
var elasticsearchVectorStore = context.getBean(ElasticsearchVectorStore.class);

assertThat(properties).isNotNull();
assertThat(properties.getIndexName()).isEqualTo("example");
assertThat(properties.getDims()).isEqualTo(1024);
assertThat(properties.isDenseVectorIndexing()).isTrue();
assertThat(properties.getSimilarity()).isEqualTo("dot_product");

assertThat(elasticsearchVectorStore).isNotNull();
});
}

private String getText(String uri) {
var resource = new DefaultResourceLoader().getResource(uri);
try {
Expand Down
1 change: 0 additions & 1 deletion vector-stores/spring-ai-elasticsearch-store/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
<dependency>
<groupId>co.elastic.clients</groupId>
<artifactId>elasticsearch-java</artifactId>
<version>8.12.2</version>
</dependency>

<!-- TESTING -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
package org.springframework.ai.vectorstore;

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty;
import co.elastic.clients.elasticsearch._types.mapping.Property;
import co.elastic.clients.elasticsearch._types.query_dsl.Query;
import co.elastic.clients.elasticsearch.core.BulkRequest;
import co.elastic.clients.elasticsearch.core.BulkResponse;
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
import co.elastic.clients.elasticsearch.core.search.Hit;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
import co.elastic.clients.json.JsonData;
Expand All @@ -38,14 +41,14 @@
import org.springframework.util.Assert;

import java.io.IOException;
import java.io.StringReader;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* @author Jemin Huh
* @author Wei Jiang
* @since 1.0.0
*/
public class ElasticsearchVectorStore implements VectorStore, InitializingBean {
Expand All @@ -55,29 +58,28 @@ public class ElasticsearchVectorStore implements VectorStore, InitializingBean {

private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class);

private static final String INDEX_NAME = "spring-ai-document-index";

private final EmbeddingClient embeddingClient;

private final ElasticsearchClient elasticsearchClient;

private final String index;
private final ElasticsearchVectorStoreOptions options;

private final FilterExpressionConverter filterExpressionConverter;

private String similarityFunction;

public ElasticsearchVectorStore(RestClient restClient, EmbeddingClient embeddingClient) {
this(INDEX_NAME, restClient, embeddingClient);
this(new ElasticsearchVectorStoreOptions(), restClient, embeddingClient);
}

public ElasticsearchVectorStore(String index, RestClient restClient, EmbeddingClient embeddingClient) {
public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient,
EmbeddingClient embeddingClient) {
Objects.requireNonNull(embeddingClient, "RestClient must not be null");
Objects.requireNonNull(embeddingClient, "EmbeddingClient must not be null");
this.elasticsearchClient = new ElasticsearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(
new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false))));
this.embeddingClient = embeddingClient;
this.index = index;
this.options = options;
this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();
// the potential functions for vector fields at
// https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-script-score-query.html#vector-functions
Expand All @@ -92,22 +94,33 @@ public ElasticsearchVectorStore withSimilarityFunction(String similarityFunction
@Override
public void add(List<Document> documents) {
BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder();

for (Document document : documents) {
if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().isEmpty()) {
logger.debug("Calling EmbeddingClient for document id = " + document.getId());
document.setEmbedding(this.embeddingClient.embed(document));
}
builkRequestBuilder
.operations(op -> op.index(idx -> idx.index(this.index).id(document.getId()).document(document)));
builkRequestBuilder.operations(op -> op
.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(document)));
}

BulkResponse bulkRequest = bulkRequest(builkRequestBuilder.build());

if (bulkRequest.errors()) {
List<BulkResponseItem> bulkResponseItems = bulkRequest.items();
for (BulkResponseItem bulkResponseItem : bulkResponseItems) {
if (bulkResponseItem.error() != null) {
throw new IllegalStateException(bulkResponseItem.error().reason());
}
}
}
bulkRequest(builkRequestBuilder.build());
}

@Override
public Optional<Boolean> delete(List<String> idList) {
BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder();
for (String id : idList)
builkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.index).id(id)));
builkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id)));
return Optional.of(bulkRequest(builkRequestBuilder.build()).errors());
}

Expand All @@ -130,11 +143,12 @@ public List<Document> similaritySearch(SearchRequest searchRequest) {

public List<Document> similaritySearch(List<Double> embedding, int topK, double similarityThreshold,
Filter.Expression filterExpression) {
return similaritySearch(new co.elastic.clients.elasticsearch.core.SearchRequest.Builder()
.query(getElasticsearchSimilarityQuery(embedding, filterExpression))
.size(topK)
.minScore(similarityThreshold)
.build());
return similaritySearch(
new co.elastic.clients.elasticsearch.core.SearchRequest.Builder().index(options.getIndexName())
.query(getElasticsearchSimilarityQuery(embedding, filterExpression))
.size(topK)
.minScore(similarityThreshold)
.build());
}

private Query getElasticsearchSimilarityQuery(List<Double> embedding, Filter.Expression filterExpression) {
Expand Down Expand Up @@ -172,22 +186,32 @@ private Document toDocument(Hit<Document> hit) {
return document;
}

public boolean exists(String targetIndex) {
private boolean indexExists() {
try {
BooleanResponse response = this.elasticsearchClient.indices()
.exists(existRequestBuilder -> existRequestBuilder.index(targetIndex));
.exists(existRequestBuilder -> existRequestBuilder.index(options.getIndexName()));
return response.value();
}
catch (IOException e) {
throw new RuntimeException(e);
}
}

public CreateIndexResponse createIndexMapping(String index, String mappingJson) {
private CreateIndexResponse createIndexMapping() {
try {
return this.elasticsearchClient.indices()
.create(createIndexBuilder -> createIndexBuilder.index(index)
.mappings(typeMappingBuilder -> typeMappingBuilder.withJson(new StringReader(mappingJson))));
.create(createIndexBuilder -> createIndexBuilder.index(options.getIndexName())
.mappings(typeMappingBuilder -> {
typeMappingBuilder.properties("embedding",
new Property.Builder()
.denseVector(new DenseVectorProperty.Builder().dims(options.getDims())
.similarity(options.getSimilarity())
.index(options.isDenseVectorIndexing())
.build())
.build());

return typeMappingBuilder;
}));
}
catch (IOException e) {
throw new RuntimeException(e);
Expand All @@ -196,19 +220,8 @@ public CreateIndexResponse createIndexMapping(String index, String mappingJson)

@Override
public void afterPropertiesSet() {
if (!exists(this.index)) {
createIndexMapping(this.index, """
{
"properties": {
"embedding": {
"type": "dense_vector",
"dims": 1536,
"index": true,
"similarity": "cosine"
}
}
}
""");
if (!indexExists()) {
createIndexMapping();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.vectorstore;

/**
* Provided Elasticsearch vector option configuration.
* https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html
*
* @author Wei Jiang
* @since 1.0.0
*/
public class ElasticsearchVectorStoreOptions {

private String indexName = "spring-ai-document-index";

private int dims = 1536;

private boolean denseVectorIndexing = true;

private String similarity = "cosine";

public String getIndexName() {
return indexName;
}

public void setIndexName(String indexName) {
this.indexName = indexName;
}

public int getDims() {
return dims;
}

public void setDims(int dims) {
this.dims = dims;
}

public boolean isDenseVectorIndexing() {
return denseVectorIndexing;
}

public void setDenseVectorIndexing(boolean denseVectorIndexing) {
this.denseVectorIndexing = denseVectorIndexing;
}

public String getSimilarity() {
return similarity;
}

public void setSimilarity(String similarity) {
this.similarity = similarity;
}

}