Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import io.micrometer.observation.ObservationRegistry;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand All @@ -40,6 +42,8 @@ public abstract class AbstractVectorStoreBuilder<T extends AbstractVectorStoreBu
@Nullable
protected VectorStoreObservationConvention customObservationConvention;

protected BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

public AbstractVectorStoreBuilder(EmbeddingModel embeddingModel) {
Assert.notNull(embeddingModel, "EmbeddingModel must be configured");
this.embeddingModel = embeddingModel;
Expand All @@ -49,6 +53,10 @@ public EmbeddingModel getEmbeddingModel() {
return this.embeddingModel;
}

public BatchingStrategy getBatchingStrategy() {
return this.batchingStrategy;
}

public ObservationRegistry getObservationRegistry() {
return this.observationRegistry;
}
Expand Down Expand Up @@ -81,4 +89,15 @@ public T customObservationConvention(@Nullable VectorStoreObservationConvention
return self();
}

/**
* Sets the batching strategy.
* @param batchingStrategy the strategy to use
* @return the builder instance
*/
public T batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return self();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentWriter;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -108,6 +109,13 @@ interface Builder<T extends Builder<T>> {
*/
T customObservationConvention(VectorStoreObservationConvention convention);

/**
* Sets the batching strategy.
* @param batchingStrategy the strategy to use
* @return the builder instance for method chaining
*/
T batchingStrategy(BatchingStrategy batchingStrategy);

/**
* Builds and returns a new VectorStore instance with the configured settings.
* @return a new VectorStore instance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.micrometer.observation.ObservationRegistry;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
Expand All @@ -47,11 +48,14 @@ public abstract class AbstractObservationVectorStore implements VectorStore {

protected final EmbeddingModel embeddingModel;

protected final BatchingStrategy batchingStrategy;

private AbstractObservationVectorStore(EmbeddingModel embeddingModel, ObservationRegistry observationRegistry,
@Nullable VectorStoreObservationConvention customObservationConvention) {
@Nullable VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) {
this.embeddingModel = embeddingModel;
this.observationRegistry = observationRegistry;
this.customObservationConvention = customObservationConvention;
this.batchingStrategy = batchingStrategy;
}

/**
Expand All @@ -60,7 +64,8 @@ private AbstractObservationVectorStore(EmbeddingModel embeddingModel, Observatio
* @param builder the builder containing configuration settings
*/
public AbstractObservationVectorStore(AbstractVectorStoreBuilder<?> builder) {
this(builder.getEmbeddingModel(), builder.getObservationRegistry(), builder.getCustomObservationConvention());
this(builder.getEmbeddingModel(), builder.getObservationRegistry(), builder.getCustomObservationConvention(),
builder.getBatchingStrategy());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,20 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.micrometer.observation.ObservationRegistry;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -98,8 +94,6 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen

private final List<String> metadataFieldsList;

private final BatchingStrategy batchingStrategy;

private CosmosAsyncContainer container;

/**
Expand All @@ -122,7 +116,6 @@ protected CosmosDBVectorStore(Builder builder) {
this.vectorStoreThroughput = builder.vectorStoreThroughput;
this.vectorDimensions = builder.vectorDimensions;
this.metadataFieldsList = builder.metadataFieldsList;
this.batchingStrategy = builder.batchingStrategy;

cosmosClient.createDatabaseIfNotExists(databaseName).block();
initializeContainer(containerName, databaseName, vectorStoreThroughput, vectorDimensions, partitionKeyPath);
Expand Down Expand Up @@ -404,8 +397,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {

private List<String> metadataFieldsList = new ArrayList<>();

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private Builder(CosmosAsyncClient cosmosClient, EmbeddingModel embeddingModel) {
super(embeddingModel);
Assert.notNull(cosmosClient, "CosmosClient must not be null");
Expand Down Expand Up @@ -483,18 +474,6 @@ public Builder metadataFields(List<String> metadataFieldsList) {
return this;
}

/**
* Sets the batching strategy.
* @param batchingStrategy the strategy to use
* @return the builder instance
* @throws IllegalArgumentException if batchingStrategy is null
*/
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return this;
}

@Override
public CosmosDBVectorStore build() {
return new CosmosDBVectorStore(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@
import com.azure.search.documents.models.SearchOptions;
import com.azure.search.documents.models.VectorSearchOptions;
import com.azure.search.documents.models.VectorizedQuery;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
Expand All @@ -60,7 +57,6 @@
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -110,8 +106,6 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements

private final boolean initializeSchema;

private final BatchingStrategy batchingStrategy;

/**
* List of metadata fields (as field name and type) that can be used in similarity
* search query filter expressions. The {@link Document#getMetadata()} can contain
Expand Down Expand Up @@ -146,7 +140,6 @@ protected AzureVectorStore(Builder builder) {
this.searchIndexClient = builder.searchIndexClient;
this.initializeSchema = builder.initializeSchema;
this.filterMetadataFields = builder.filterMetadataFields;
this.batchingStrategy = builder.batchingStrategy;
this.defaultTopK = builder.defaultTopK;
this.defaultSimilarityThreshold = builder.defaultSimilarityThreshold;
this.indexName = builder.indexName;
Expand Down Expand Up @@ -389,8 +382,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {

private List<MetadataField> filterMetadataFields = List.of();

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private int defaultTopK = DEFAULT_TOP_K;

private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD;
Expand Down Expand Up @@ -423,17 +414,6 @@ public Builder filterMetadataFields(List<MetadataField> filterMetadataFields) {
return this;
}

/**
* Sets the batching strategy.
* @param batchingStrategy the strategy to use
* @return the builder instance
*/
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return this;
}

/**
* Sets the index name for the Azure Vector Store.
* @param indexName the name of the index to use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme

private final boolean closeSessionOnClose;

private final BatchingStrategy batchingStrategy;

private final ConcurrentMap<Set<String>, PreparedStatement> addStmts = new ConcurrentHashMap<>();

private final PreparedStatement deleteStmt;
Expand All @@ -239,7 +237,6 @@ protected CassandraVectorStore(Builder builder) {
this.primaryKeyTranslator = builder.primaryKeyTranslator;
this.executor = Executors.newFixedThreadPool(builder.fixedThreadPoolExecutorSize);
this.closeSessionOnClose = builder.closeSessionOnClose;
this.batchingStrategy = builder.batchingStrategy;

ensureSchemaExists(embeddingModel.dimensions());
prepareAddStatement(Set.of());
Expand Down Expand Up @@ -777,8 +774,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {

private int fixedThreadPoolExecutorSize = DEFAULT_ADD_CONCURRENCY;

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private FilterExpressionConverter filterExpressionConverter;

private DocumentIdTranslator documentIdTranslator = (String id) -> List.of(id);
Expand Down Expand Up @@ -917,18 +912,6 @@ public Builder disallowSchemaChanges(boolean disallowSchemaChanges) {
return this;
}

/**
* Sets the batching strategy.
* @param batchingStrategy the batching strategy to use
* @return the builder instance
* @throws IllegalArgumentException if batchingStrategy is null
*/
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return this;
}

/**
* Sets the filter expression converter.
* @param converter the filter expression converter to use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements

private final boolean initializeSchema;

private final BatchingStrategy batchingStrategy;

private final ObjectMapper objectMapper;

private boolean initialized = false;
Expand All @@ -95,7 +93,6 @@ protected ChromaVectorStore(Builder builder) {
this.collectionName = builder.collectionName;
this.initializeSchema = builder.initializeSchema;
this.filterExpressionConverter = builder.filterExpressionConverter;
this.batchingStrategy = builder.batchingStrategy;
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();

if (builder.initializeImmediately) {
Expand Down Expand Up @@ -232,8 +229,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {

private boolean initializeSchema = false;

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private FilterExpressionConverter filterExpressionConverter = new ChromaFilterExpressionConverter();

private boolean initializeImmediately = false;
Expand Down Expand Up @@ -266,18 +261,6 @@ public Builder initializeSchema(boolean initializeSchema) {
return this;
}

/**
* Sets the batching strategy.
* @param batchingStrategy the batching strategy to use
* @return the builder instance
* @throws IllegalArgumentException if batchingStrategy is null
*/
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "batchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return this;
}

/**
* Sets the filter expression converter.
* @param converter the filter expression converter to use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,6 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp

private final boolean initializeSchema;

private final BatchingStrategy batchingStrategy;

protected ElasticsearchVectorStore(Builder builder) {
super(builder);

Expand All @@ -174,7 +172,6 @@ protected ElasticsearchVectorStore(Builder builder) {
this.initializeSchema = builder.initializeSchema;
this.options = builder.options;
this.filterExpressionConverter = builder.filterExpressionConverter;
this.batchingStrategy = builder.batchingStrategy;

String version = Version.VERSION == null ? "Unknown" : Version.VERSION.toString();
this.elasticsearchClient = new ElasticsearchClient(new RestClientTransport(builder.restClient,
Expand Down Expand Up @@ -371,8 +368,6 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {

private boolean initializeSchema = false;

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private FilterExpressionConverter filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();

/**
Expand Down Expand Up @@ -408,18 +403,6 @@ public Builder initializeSchema(boolean initializeSchema) {
return this;
}

/**
* Sets the batching strategy for vector operations.
* @param batchingStrategy the batching strategy to use
* @return the builder instance
* @throws IllegalArgumentException if batchingStrategy is null
*/
public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "batchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return this;
}

/**
* Sets the filter expression converter.
* @param converter the filter expression converter to use
Expand Down
Loading
Loading