-
Notifications
You must be signed in to change notification settings - Fork 2k
Enhancing Elasticsearch vector store implementation #592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 9 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
86650ac
knn instead of script_score, removed initialization
l-trotta 8180189
only using normalized similarities, adjusted unit test
l-trotta a73509b
import clean
l-trotta 0ab9981
making l2norm's distances consistent with others
l-trotta 3e31e00
refactor unit test
l-trotta af5d8c1
rebase
l-trotta 5acdfec
format
l-trotta db32740
dependency version, docs
l-trotta 8e08cb0
rebase
l-trotta ca9fc88
autoconfigure test fix
l-trotta File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,17 +16,13 @@ | |
package org.springframework.ai.vectorstore; | ||
|
||
import co.elastic.clients.elasticsearch.ElasticsearchClient; | ||
import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty; | ||
import co.elastic.clients.elasticsearch._types.mapping.Property; | ||
import co.elastic.clients.elasticsearch._types.query_dsl.Query; | ||
import co.elastic.clients.elasticsearch.core.BulkRequest; | ||
import co.elastic.clients.elasticsearch.core.BulkResponse; | ||
import co.elastic.clients.elasticsearch.core.SearchResponse; | ||
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; | ||
import co.elastic.clients.elasticsearch.core.search.Hit; | ||
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse; | ||
import co.elastic.clients.json.JsonData; | ||
import co.elastic.clients.json.jackson.JacksonJsonpMapper; | ||
import co.elastic.clients.transport.endpoints.BooleanResponse; | ||
import co.elastic.clients.transport.rest_client.RestClientTransport; | ||
import com.fasterxml.jackson.databind.DeserializationFeature; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
|
@@ -46,16 +42,17 @@ | |
import java.util.Optional; | ||
import java.util.stream.Collectors; | ||
|
||
import static java.lang.Math.sqrt; | ||
import static org.springframework.ai.vectorstore.SimilarityFunction.l2_norm; | ||
|
||
/** | ||
* @author Jemin Huh | ||
* @author Wei Jiang | ||
* @author Laura Trotta | ||
* @since 1.0.0 | ||
*/ | ||
public class ElasticsearchVectorStore implements VectorStore, InitializingBean { | ||
|
||
// divided by 2 to get score in the range [0, 1] | ||
public static final String COSINE_SIMILARITY_FUNCTION = "(cosineSimilarity(params.query_vector, 'embedding') + 1.0) / 2"; | ||
|
||
private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class); | ||
|
||
private final EmbeddingModel embeddingModel; | ||
|
@@ -66,8 +63,6 @@ public class ElasticsearchVectorStore implements VectorStore, InitializingBean { | |
|
||
private final FilterExpressionConverter filterExpressionConverter; | ||
|
||
private String similarityFunction; | ||
|
||
private final boolean initializeSchema; | ||
|
||
public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) { | ||
|
@@ -84,30 +79,22 @@ public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestCli | |
this.embeddingModel = embeddingModel; | ||
this.options = options; | ||
this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter(); | ||
// the potential functions for vector fields at | ||
// https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-script-score-query.html#vector-functions | ||
this.similarityFunction = COSINE_SIMILARITY_FUNCTION; | ||
} | ||
|
||
public ElasticsearchVectorStore withSimilarityFunction(String similarityFunction) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it required to remove the ability to change the similarity function at runtime? |
||
this.similarityFunction = similarityFunction; | ||
return this; | ||
} | ||
|
||
@Override | ||
public void add(List<Document> documents) { | ||
BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder(); | ||
BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); | ||
|
||
for (Document document : documents) { | ||
if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().isEmpty()) { | ||
logger.debug("Calling EmbeddingModel for document id = " + document.getId()); | ||
document.setEmbedding(this.embeddingModel.embed(document)); | ||
} | ||
builkRequestBuilder.operations(op -> op | ||
bulkRequestBuilder.operations(op -> op | ||
.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(document))); | ||
} | ||
|
||
BulkResponse bulkRequest = bulkRequest(builkRequestBuilder.build()); | ||
BulkResponse bulkRequest = bulkRequest(bulkRequestBuilder.build()); | ||
|
||
if (bulkRequest.errors()) { | ||
List<BulkResponseItem> bulkResponseItems = bulkRequest.items(); | ||
|
@@ -121,10 +108,10 @@ public void add(List<Document> documents) { | |
|
||
@Override | ||
public Optional<Boolean> delete(List<String> idList) { | ||
BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder(); | ||
BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); | ||
for (String id : idList) | ||
builkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id))); | ||
return Optional.of(bulkRequest(builkRequestBuilder.build()).errors()); | ||
bulkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id))); | ||
return Optional.of(bulkRequest(bulkRequestBuilder.build()).errors()); | ||
} | ||
|
||
private BulkResponse bulkRequest(BulkRequest bulkRequest) { | ||
|
@@ -139,61 +126,67 @@ private BulkResponse bulkRequest(BulkRequest bulkRequest) { | |
@Override | ||
public List<Document> similaritySearch(SearchRequest searchRequest) { | ||
Assert.notNull(searchRequest, "The search request must not be null."); | ||
return similaritySearch(this.embeddingModel.embed(searchRequest.getQuery()), searchRequest.getTopK(), | ||
Double.valueOf(searchRequest.getSimilarityThreshold()).floatValue(), | ||
searchRequest.getFilterExpression()); | ||
} | ||
|
||
public List<Document> similaritySearch(List<Double> embedding, int topK, double similarityThreshold, | ||
Filter.Expression filterExpression) { | ||
return similaritySearch( | ||
new co.elastic.clients.elasticsearch.core.SearchRequest.Builder().index(options.getIndexName()) | ||
.query(getElasticsearchSimilarityQuery(embedding, filterExpression)) | ||
.size(topK) | ||
.minScore(similarityThreshold) | ||
.build()); | ||
} | ||
|
||
private Query getElasticsearchSimilarityQuery(List<Double> embedding, Filter.Expression filterExpression) { | ||
return Query.of(queryBuilder -> queryBuilder.scriptScore(scriptScoreQueryBuilder -> scriptScoreQueryBuilder | ||
.query(queryBuilder2 -> queryBuilder2.queryString(queryStringQuerybuilder -> queryStringQuerybuilder | ||
.query(getElasticsearchQueryString(filterExpression)))) | ||
.script(scriptBuilder -> scriptBuilder | ||
.inline(inlineScriptBuilder -> inlineScriptBuilder.source(this.similarityFunction) | ||
.params("query_vector", JsonData.of(embedding)))))); | ||
} | ||
|
||
private String getElasticsearchQueryString(Filter.Expression filterExpression) { | ||
return Objects.isNull(filterExpression) ? "*" | ||
: this.filterExpressionConverter.convertExpression(filterExpression); | ||
|
||
} | ||
|
||
private List<Document> similaritySearch(co.elastic.clients.elasticsearch.core.SearchRequest searchRequest) { | ||
try { | ||
return this.elasticsearchClient.search(searchRequest, Document.class) | ||
.hits() | ||
.hits() | ||
float threshold = (float) searchRequest.getSimilarityThreshold(); | ||
// reverting l2_norm distance to its original value | ||
if (options.getSimilarity().equals(l2_norm)) { | ||
threshold = 1 - threshold; | ||
} | ||
final float finalThreshold = threshold; | ||
List<Float> vectors = this.embeddingModel.embed(searchRequest.getQuery()) | ||
.stream() | ||
.map(this::toDocument) | ||
.collect(Collectors.toList()); | ||
.map(Double::floatValue) | ||
.toList(); | ||
|
||
SearchResponse<Document> res = elasticsearchClient.search( | ||
sr -> sr.index(options.getIndexName()) | ||
.knn(knn -> knn.queryVector(vectors) | ||
.similarity(finalThreshold) | ||
.k((long) searchRequest.getTopK()) | ||
.field("embedding") | ||
.numCandidates((long) (1.5 * searchRequest.getTopK())) | ||
.filter(fl -> fl.queryString( | ||
qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))), | ||
Document.class); | ||
|
||
return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList()); | ||
} | ||
catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
private String getElasticsearchQueryString(Filter.Expression filterExpression) { | ||
return Objects.isNull(filterExpression) ? "*" | ||
: this.filterExpressionConverter.convertExpression(filterExpression); | ||
|
||
} | ||
|
||
private Document toDocument(Hit<Document> hit) { | ||
Document document = hit.source(); | ||
document.getMetadata().put("distance", 1 - hit.score().floatValue()); | ||
document.getMetadata().put("distance", calculateDistance(hit.score().floatValue())); | ||
return document; | ||
} | ||
|
||
private boolean indexExists() { | ||
// more info on score/distance calculation | ||
// https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search | ||
private float calculateDistance(Float score) { | ||
switch (options.getSimilarity()) { | ||
case l2_norm: | ||
// the returned value of l2_norm is the opposite of the other functions | ||
// (closest to zero means more accurate), so to make it consistent | ||
// with the other functions the reverse is returned applying a "1-" | ||
// to the standard transformation | ||
return (float) (1 - (sqrt((1 / score) - 1))); | ||
// cosine and dot_product | ||
default: | ||
return (2 * score) - 1; | ||
} | ||
} | ||
|
||
public boolean indexExists() { | ||
try { | ||
BooleanResponse response = this.elasticsearchClient.indices() | ||
.exists(existRequestBuilder -> existRequestBuilder.index(options.getIndexName())); | ||
return response.value(); | ||
return this.elasticsearchClient.indices().exists(ex -> ex.index(options.getIndexName())).value(); | ||
} | ||
catch (IOException e) { | ||
throw new RuntimeException(e); | ||
|
@@ -203,18 +196,9 @@ private boolean indexExists() { | |
private CreateIndexResponse createIndexMapping() { | ||
try { | ||
return this.elasticsearchClient.indices() | ||
.create(createIndexBuilder -> createIndexBuilder.index(options.getIndexName()) | ||
.mappings(typeMappingBuilder -> { | ||
typeMappingBuilder.properties("embedding", | ||
new Property.Builder() | ||
.denseVector(new DenseVectorProperty.Builder().dims(options.getDimensions()) | ||
.similarity(options.getSimilarity()) | ||
.index(options.isDenseVectorIndexing()) | ||
.build()) | ||
.build()); | ||
|
||
return typeMappingBuilder; | ||
})); | ||
.create(cr -> cr.index(options.getIndexName()) | ||
.mappings(map -> map.properties("embedding", p -> p.denseVector( | ||
dv -> dv.similarity(options.getSimilarity().toString()).dims(options.getDimensions()))))); | ||
} | ||
catch (IOException e) { | ||
throw new RuntimeException(e); | ||
|
@@ -233,4 +217,4 @@ public void afterPropertiesSet() { | |
} | ||
} | ||
|
||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the reason to add the elasticsearch-java explicitly here? As it is already defined in the vector-store dependenies?