diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java index 51349f1a345..14b3324853b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.chroma; import org.springframework.ai.chroma.ChromaApi; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.ChromaVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; @@ -36,6 +39,7 @@ /** * @author Christian Tzolov * @author Eddú Meléndez + * @author Soby Chacko */ @AutoConfiguration @ConditionalOnClass({ EmbeddingModel.class, RestClient.class, ChromaVectorStore.class, ObjectMapper.class }) @@ -73,14 +77,21 @@ else if (StringUtils.hasText(apiProperties.getUsername()) && StringUtils.hasText return chromaApi; } + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy chromaBatchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public ChromaVectorStore vectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, ChromaVectorStoreProperties storeProperties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy chromaBatchingStrategy) { return new ChromaVectorStore(embeddingModel, chromaApi, storeProperties.getCollectionName(), storeProperties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), chromaBatchingStrategy); } static class PropertiesChromaConnectionDetails implements ChromaConnectionDetails { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java index b943a676c4c..115773ba7d1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/elasticsearch/ElasticsearchVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.elasticsearch; import org.elasticsearch.client.RestClient; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.ElasticsearchVectorStore; import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; @@ -37,19 +40,26 @@ * @author Wei Jiang * @author Josh Long * @author Christian Tzolov + * @author Soby Chacko * @since 1.0.0 */ - @AutoConfiguration(after = ElasticsearchRestClientAutoConfiguration.class) @ConditionalOnClass({ ElasticsearchVectorStore.class, EmbeddingModel.class, RestClient.class }) @EnableConfigurationProperties(ElasticsearchVectorStoreProperties.class) class ElasticsearchVectorStoreAutoConfiguration { + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properties, RestClient restClient, EmbeddingModel embeddingModel, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { ElasticsearchVectorStoreOptions elasticsearchVectorStoreOptions = new ElasticsearchVectorStoreOptions(); if (StringUtils.hasText(properties.getIndexName())) { @@ -64,7 +74,7 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti return new ElasticsearchVectorStore(elasticsearchVectorStoreOptions, restClient, embeddingModel, properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java index 3b058636b73..3faaa3b644a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.neo4j; import org.neo4j.driver.Driver; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.Neo4jVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; @@ -34,17 +37,25 @@ * @author Jingzhou Ou * @author Josh Long * @author Christian Tzolov + * @author Soby Chacko */ @AutoConfiguration(after = Neo4jAutoConfiguration.class) @ConditionalOnClass({ Neo4jVectorStore.class, EmbeddingModel.class, Driver.class }) @EnableConfigurationProperties({ Neo4jVectorStoreProperties.class }) public class Neo4jVectorStoreAutoConfiguration { + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public Neo4jVectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVectorStoreProperties properties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { Neo4jVectorStore.Neo4jVectorStoreConfig config = Neo4jVectorStore.Neo4jVectorStoreConfig.builder() .withDatabaseName(properties.getDatabaseName()) .withEmbeddingDimension(properties.getEmbeddingDimension()) @@ -58,7 +69,7 @@ public Neo4jVectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel return new Neo4jVectorStore(driver, embeddingModel, config, properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java index 489842a8e10..d1cb2379070 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/qdrant/QdrantVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.qdrant; import io.micrometer.observation.ObservationRegistry; import io.qdrant.client.QdrantClient; import io.qdrant.client.QdrantGrpcClient; + +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore; import org.springframework.beans.factory.ObjectProvider; @@ -32,6 +36,7 @@ * @author Anush Shetty * @author Eddú Meléndez * @author Christian Tzolov + * @author Soby Chacko * @since 0.8.1 */ @AutoConfiguration @@ -58,14 +63,21 @@ public QdrantClient qdrantClient(QdrantVectorStoreProperties properties, return new QdrantClient(grpcClientBuilder.build()); } + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public QdrantVectorStore vectorStore(EmbeddingModel embeddingModel, QdrantVectorStoreProperties properties, QdrantClient qdrantClient, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { return new QdrantVectorStore(qdrantClient, properties.getCollectionName(), embeddingModel, properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } static class PropertiesQdrantConnectionDetails implements QdrantConnectionDetails { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java index 20e23bf1ad4..92631831b60 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.redis; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; @@ -35,6 +38,7 @@ /** * @author Christian Tzolov * @author Eddú Meléndez + * @author Soby Chacko */ @AutoConfiguration(after = RedisAutoConfiguration.class) @ConditionalOnClass({ JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class }) @@ -42,11 +46,18 @@ @EnableConfigurationProperties(RedisVectorStoreProperties.class) public class RedisVectorStoreAutoConfiguration { + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties, JedisConnectionFactory jedisConnectionFactory, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { var config = RedisVectorStoreConfig.builder() .withIndexName(properties.getIndex()) @@ -56,7 +67,7 @@ public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorSt return new RedisVectorStore(config, embeddingModel, new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java index 48a77102724..de6e9a49033 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/typesense/TypesenseVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.typesense; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.TypesenseVectorStore; import org.springframework.ai.vectorstore.TypesenseVectorStore.TypesenseVectorStoreConfig; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; @@ -38,6 +41,7 @@ /** * @author Pablo Sanchidrian Herrera * @author Eddú Meléndez + * @author Soby Chacko */ @AutoConfiguration @ConditionalOnClass({ TypesenseVectorStore.class, EmbeddingModel.class }) @@ -51,11 +55,18 @@ TypesenseVectorStoreAutoConfiguration.PropertiesTypesenseConnectionDetails types return new TypesenseVectorStoreAutoConfiguration.PropertiesTypesenseConnectionDetails(properties); } + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public TypesenseVectorStore vectorStore(Client typesenseClient, EmbeddingModel embeddingModel, TypesenseVectorStoreProperties properties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { TypesenseVectorStoreConfig config = TypesenseVectorStoreConfig.builder() .withCollectionName(properties.getCollectionName()) @@ -64,7 +75,7 @@ public TypesenseVectorStore vectorStore(Client typesenseClient, EmbeddingModel e return new TypesenseVectorStore(typesenseClient, embeddingModel, config, properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java index 4c817ee58ce..9ca3899db74 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.weaviate; import io.micrometer.observation.ObservationRegistry; @@ -20,7 +21,10 @@ import io.weaviate.client.WeaviateAuthClient; import io.weaviate.client.WeaviateClient; import io.weaviate.client.v1.auth.exception.AuthException; + +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.WeaviateVectorStore; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; @@ -62,11 +66,18 @@ public WeaviateClient weaviateClient(WeaviateVectorStoreProperties properties, } } + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, WeaviateClient weaviateClient, WeaviateVectorStoreProperties properties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { WeaviateVectorStoreConfig.Builder configBuilder = WeaviateVectorStore.WeaviateVectorStoreConfig.builder() .withObjectClass(properties.getObjectClass()) @@ -79,7 +90,7 @@ public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, WeaviateCl return new WeaviateVectorStore(configBuilder.build(), embeddingModel, weaviateClient, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } static class PropertiesWeaviateConnectionDetails implements WeaviateConnectionDetails { diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 321b75a684c..a7f9f4782ca 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.ArrayList; @@ -26,7 +27,10 @@ import org.springframework.ai.chroma.ChromaApi.DeleteEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.Embedding; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; @@ -69,18 +73,21 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, boolean initializeSchema) { this(embeddingModel, chromaApi, DEFAULT_COLLECTION_NAME, initializeSchema); } public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, String collectionName, boolean initializeSchema) { - this(embeddingModel, chromaApi, collectionName, initializeSchema, ObservationRegistry.NOOP, null); + this(embeddingModel, chromaApi, collectionName, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, String collectionName, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -89,6 +96,7 @@ public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, Str this.collectionName = collectionName; this.initializeSchema = initializeSchema; this.filterExpressionConverter = new ChromaFilterExpressionConverter(); + this.batchingStrategy = batchingStrategy; } public void setFilterExpressionConverter(FilterExpressionConverter filterExpressionConverter) { @@ -108,11 +116,13 @@ public void doAdd(List documents) { List contents = new ArrayList<>(); List embeddings = new ArrayList<>(); + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + for (Document document : documents) { ids.add(document.getId()); metadatas.add(document.getMetadata()); contents.add(document.getContent()); - document.setEmbedding(this.embeddingModel.embed(document)); + document.setEmbedding(document.getEmbedding()); embeddings.add(document.getEmbedding()); } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java index d0bae901919..120485514cf 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java @@ -26,6 +26,7 @@ import org.springframework.ai.chroma.ChromaApi; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -169,7 +170,8 @@ public ChromaApi chromaApi(RestClient.Builder builder) { @Bean public VectorStore chromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, ObservationRegistry observationRegistry) { - return new ChromaVectorStore(embeddingModel, chromaApi, "TestCollection", true, observationRegistry, null); + return new ChromaVectorStore(embeddingModel, chromaApi, "TestCollection", true, observationRegistry, null, + new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 56db121a85f..cb185bc6d16 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -28,7 +28,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -85,18 +88,21 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) { this(new ElasticsearchVectorStoreOptions(), restClient, embeddingModel, initializeSchema); } public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) { - this(options, restClient, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null); + this(options, restClient, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -108,17 +114,16 @@ public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestCli this.embeddingModel = embeddingModel; this.options = options; this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter(); + this.batchingStrategy = batchingStrategy; } @Override public void doAdd(List documents) { BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + for (Document document : documents) { - if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().length == 0) { - logger.debug("Calling EmbeddingModel for document id = " + document.getId()); - document.setEmbedding(this.embeddingModel.embed(document)); - } // We call operations on BulkRequest.Builder only if the index exists. // For the index to be present, either it must be pre-created or set the // initializeSchema to true. diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java index 254e903fe3d..e6bb89b9960 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java @@ -33,6 +33,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -202,7 +203,7 @@ public TestObservationRegistry observationRegistry() { public ElasticsearchVectorStore vectorStoreDefault(EmbeddingModel embeddingModel, RestClient restClient, ObservationRegistry observationRegistry) { return new ElasticsearchVectorStore(new ElasticsearchVectorStoreOptions(), restClient, embeddingModel, true, - observationRegistry, null); + observationRegistry, null, new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java index 49be7c23a01..55c169d242f 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.HashMap; @@ -26,7 +27,10 @@ import org.neo4j.driver.SessionConfig; import org.neo4j.driver.Values; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.filter.Neo4jVectorFilterExpressionConverter; @@ -43,6 +47,7 @@ * @author Michael Simons * @author Christian Tzolov * @author Thomas Vitale + * @author Soby Chacko */ public class Neo4jVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -284,31 +289,33 @@ public Neo4jVectorStoreConfig build() { private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVectorStoreConfig config, boolean initializeSchema) { - this(driver, embeddingModel, config, initializeSchema, ObservationRegistry.NOOP, null); + this(driver, embeddingModel, config, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } public Neo4jVectorStore(Driver driver, EmbeddingModel embeddingModel, Neo4jVectorStoreConfig config, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { - + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); this.initializeSchema = initializeSchema; - Assert.notNull(driver, "Neo4j driver must not be null"); Assert.notNull(embeddingModel, "Embedding model must not be null"); - this.driver = driver; this.embeddingModel = embeddingModel; - this.config = config; + this.batchingStrategy = batchingStrategy; } @Override public void doAdd(List documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + var rows = documents.stream().map(this::documentToRecord).toList(); try (var session = this.driver.session()) { @@ -398,8 +405,7 @@ public void afterPropertiesSet() { } private Map documentToRecord(Document document) { - var embedding = this.embeddingModel.embed(document); - document.setEmbedding(embedding); + document.setEmbedding(document.getEmbedding()); var row = new HashMap(); @@ -411,7 +417,7 @@ private Map documentToRecord(Document document) { document.getMetadata().forEach((k, v) -> properties.put("metadata." + k, Values.value(v))); row.put("properties", properties); - row.put(this.config.embeddingProperty, Values.value(embedding)); + row.put(this.config.embeddingProperty, Values.value(document.getEmbedding())); return row; } diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java index 540fc116ba8..4f0b7bcd748 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreObservationIT.java @@ -31,6 +31,7 @@ import org.neo4j.driver.GraphDatabase; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -174,7 +175,7 @@ public VectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel, ObservationRegistry observationRegistry) { return new Neo4jVectorStore(driver, embeddingModel, Neo4jVectorStore.Neo4jVectorStoreConfig.defaultConfig(), - true, observationRegistry, null); + true, observationRegistry, null, new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java index 4f45ece4c11..158caf90352 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.qdrant; import static io.qdrant.client.PointIdFactory.id; @@ -27,7 +28,10 @@ import java.util.concurrent.ExecutionException; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -58,6 +62,7 @@ * @author Christian Tzolov * @author Eddú Meléndez * @author Josh Long + * @author Soby Chacko * @since 0.8.1 */ public class QdrantVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -78,6 +83,8 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + /** * Configuration class for the QdrantVectorStore. * @@ -161,7 +168,8 @@ public QdrantVectorStore(QdrantClient qdrantClient, QdrantVectorStoreConfig conf */ public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, EmbeddingModel embeddingModel, boolean initializeSchema) { - this(qdrantClient, collectionName, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null); + this(qdrantClient, collectionName, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } /** @@ -175,7 +183,7 @@ public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, Embed */ public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, EmbeddingModel embeddingModel, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -187,6 +195,7 @@ public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, Embed this.embeddingModel = embeddingModel; this.collectionName = collectionName; this.qdrantClient = qdrantClient; + this.batchingStrategy = batchingStrategy; } /** @@ -196,16 +205,17 @@ public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, Embed @Override public void doAdd(List documents) { try { - List points = documents.stream().map(document -> { - // Compute and assign an embedding to the document. - document.setEmbedding(this.embeddingModel.embed(document)); - return PointStruct.newBuilder() + // Compute and assign an embedding to the document. + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + + List points = documents.stream() + .map(document -> PointStruct.newBuilder() .setId(id(UUID.fromString(document.getId()))) .setVectors(vectors(document.getEmbedding())) .putAllPayload(toPayload(document)) - .build(); - }).toList(); + .build()) + .toList(); this.qdrantClient.upsertAsync(this.collectionName, points).get(); } diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java index be5e190525d..0dbf6765708 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreObservationIT.java @@ -28,6 +28,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.mistralai.MistralAiEmbeddingModel; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.observation.conventions.SpringAiKind; @@ -191,8 +192,8 @@ public QdrantClient qdrantClient() { @Bean public VectorStore qdrantVectorStore(EmbeddingModel embeddingModel, QdrantClient qdrantClient, ObservationRegistry observationRegistry) { - return new QdrantVectorStore(qdrantClient, COLLECTION_NAME, embeddingModel, true, observationRegistry, - null); + return new QdrantVectorStore(qdrantClient, COLLECTION_NAME, embeddingModel, true, observationRegistry, null, + new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java index de72bcc9f51..e7e584fc80c 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.text.MessageFormat; @@ -29,7 +30,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; @@ -75,6 +79,7 @@ * @author Christian Tzolov * @author Eddú Meléndez * @author Thomas Vitale + * @author Soby Chacko * @see VectorStore * @see RedisVectorStoreConfig * @see EmbeddingModel @@ -278,15 +283,18 @@ public RedisVectorStoreConfig build() { private FilterExpressionConverter filterExpressionConverter; + private final BatchingStrategy batchingStrategy; + public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, boolean initializeSchema) { - this(config, embeddingModel, jedis, initializeSchema, ObservationRegistry.NOOP, null); + this(config, embeddingModel, jedis, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -298,6 +306,7 @@ public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingM this.embeddingModel = embeddingModel; this.config = config; this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields); + this.batchingStrategy = batchingStrategy; } public JedisPooled getJedis() { @@ -307,12 +316,13 @@ public JedisPooled getJedis() { @Override public void doAdd(List documents) { try (Pipeline pipeline = this.jedis.pipelined()) { - for (Document document : documents) { - var embedding = this.embeddingModel.embed(document); - document.setEmbedding(embedding); + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + + for (Document document : documents) { + document.setEmbedding(document.getEmbedding()); var fields = new HashMap(); - fields.put(this.config.embeddingFieldName, embedding); + fields.put(this.config.embeddingFieldName, document.getEmbedding()); fields.put(this.config.contentFieldName, document.getContent()); fields.putAll(document.getMetadata()); pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields); diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java index fb2d018cc81..2a7b717bfd6 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java @@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -178,7 +179,7 @@ public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, .build(), embeddingModel, new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), true, - observationRegistry, null); + observationRegistry, null, new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java index 8dda34eaf29..c27161a0f84 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java @@ -26,7 +26,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; @@ -85,6 +88,8 @@ public class TypesenseVectorStore extends AbstractObservationVectorStore impleme private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + public static class TypesenseVectorStoreConfig { private final String collectionName; @@ -162,12 +167,13 @@ public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel) { public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel, TypesenseVectorStoreConfig config, boolean initializeSchema) { - this(client, embeddingModel, config, initializeSchema, ObservationRegistry.NOOP, null); + this(client, embeddingModel, config, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel, TypesenseVectorStoreConfig config, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -178,19 +184,21 @@ public TypesenseVectorStore(Client client, EmbeddingModel embeddingModel, Typese this.embeddingModel = embeddingModel; this.config = config; this.initializeSchema = initializeSchema; + this.batchingStrategy = batchingStrategy; } @Override public void doAdd(List documents) { Assert.notNull(documents, "Documents must not be null"); + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List> documentList = documents.stream().map(document -> { HashMap typesenseDoc = new HashMap<>(); typesenseDoc.put(DOC_ID_FIELD_NAME, document.getId()); typesenseDoc.put(CONTENT_FIELD_NAME, document.getContent()); typesenseDoc.put(METADATA_FIELD_NAME, document.getMetadata()); - float[] embedding = this.embeddingModel.embed(document.getContent()); - typesenseDoc.put(EMBEDDING_FIELD_NAME, embedding); + typesenseDoc.put(EMBEDDING_FIELD_NAME, document.getEmbedding()); return typesenseDoc; }).toList(); diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java index 1a757646c3f..47d35b40d8a 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreObservationIT.java @@ -27,6 +27,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -169,7 +170,8 @@ public VectorStore vectorStore(Client client, EmbeddingModel embeddingModel, .withEmbeddingDimension(embeddingModel.dimensions()) .build(); - return new TypesenseVectorStore(client, embeddingModel, config, true, observationRegistry, null); + return new TypesenseVectorStore(client, embeddingModel, config, true, observationRegistry, null, + new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java index 7031b36592b..0c381b4e754 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.ArrayList; @@ -24,7 +25,10 @@ import java.util.stream.Collectors; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; @@ -96,6 +100,8 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore { private final String weaviateObjectClass; + private final BatchingStrategy batchingStrategy; + /** * List of metadata fields (as field name and type) that can be used in similarity * search query filter expressions. The {@link Document#getMetadata()} can contain @@ -290,7 +296,8 @@ public WeaviateVectorStoreConfig build() { */ public WeaviateVectorStore(WeaviateVectorStoreConfig vectorStoreConfig, EmbeddingModel embeddingModel, WeaviateClient weaviateClient) { - this(vectorStoreConfig, embeddingModel, weaviateClient, ObservationRegistry.NOOP, null); + this(vectorStoreConfig, embeddingModel, weaviateClient, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } /** @@ -303,7 +310,7 @@ public WeaviateVectorStore(WeaviateVectorStoreConfig vectorStoreConfig, Embeddin */ public WeaviateVectorStore(WeaviateVectorStoreConfig vectorStoreConfig, EmbeddingModel embeddingModel, WeaviateClient weaviateClient, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -318,6 +325,7 @@ public WeaviateVectorStore(WeaviateVectorStoreConfig vectorStoreConfig, Embeddin this.filterMetadataFields.stream().map(MetadataField::name).toList()); this.weaviateClient = weaviateClient; this.weaviateSimilaritySearchFields = buildWeaviateSimilaritySearchFields(); + this.batchingStrategy = batchingStrategy; } private Field[] buildWeaviateSimilaritySearchFields() { @@ -347,6 +355,8 @@ public void doAdd(List documents) { return; } + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List weaviateObjects = documents.stream().map(this::toWeaviateObject).toList(); Result response = this.weaviateClient.batch() @@ -385,11 +395,6 @@ public void doAdd(List documents) { private WeaviateObject toWeaviateObject(Document document) { - if (document.getEmbedding() == null || document.getEmbedding().length == 0) { - float[] embedding = this.embeddingModel.embed(document); - document.setEmbedding(embedding); - } - // https://weaviate.io/developers/weaviate/config-refs/datatypes Map fields = new HashMap<>(); fields.put(CONTENT_FIELD_NAME, document.getContent()); diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java index a2228cf35bb..20d04a97d9a 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreObservationIT.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -165,7 +166,8 @@ public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, Observatio .withConsistencyLevel(WeaviateVectorStoreConfig.ConsistentLevel.ONE) .build(); - return new WeaviateVectorStore(config, embeddingModel, weaviateClient, observationRegistry, null); + return new WeaviateVectorStore(config, embeddingModel, weaviateClient, observationRegistry, null, + new TokenCountBatchingStrategy()); }