diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java index 0c67399c51a..c3fcd3b114c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.azure; import com.azure.core.credential.AzureKeyCredential; @@ -23,7 +24,9 @@ import java.util.List; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.azure.AzureVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; @@ -36,6 +39,7 @@ /** * @author Christian Tzolov + * @author Soby Chacko */ @AutoConfiguration @ConditionalOnClass({ EmbeddingModel.class, SearchIndexClient.class, AzureVectorStore.class }) @@ -51,15 +55,22 @@ public SearchIndexClient searchIndexClient(AzureVectorStoreProperties properties .buildClient(); } + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public AzureVectorStore vectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel, AzureVectorStoreProperties properties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { var vectorStore = new AzureVectorStore(searchIndexClient, embeddingModel, properties.isInitializeSchema(), List.of(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); vectorStore.setIndexName(properties.getIndexName()); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java index 3d014f4be5b..0431133b985 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.cassandra; import java.time.Duration; @@ -22,7 +23,9 @@ import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.CassandraVectorStore; import org.springframework.ai.vectorstore.CassandraVectorStoreConfig; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; @@ -38,6 +41,7 @@ /** * @author Mick Semb Wever * @author Christian Tzolov + * @author Soby Chacko * @since 1.0.0 */ @AutoConfiguration(after = CassandraAutoConfiguration.class) @@ -45,11 +49,18 @@ @EnableConfigurationProperties(CassandraVectorStoreProperties.class) public class CassandraVectorStoreAutoConfiguration { + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, CassandraVectorStoreProperties properties, CqlSession cqlSession, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { var builder = CassandraVectorStoreConfig.builder().withCqlSession(cqlSession); @@ -69,7 +80,7 @@ public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, Cassandra return new CassandraVectorStore(builder.build(), embeddingModel, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java index d74de9d5c02..f9ed2f8ce6e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java @@ -16,7 +16,9 @@ package org.springframework.ai.autoconfigure.vectorstore.gemfire; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.GemFireVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; @@ -47,11 +49,18 @@ GemFireVectorStoreAutoConfiguration.PropertiesGemFireConnectionDetails gemfireCo return new GemFireVectorStoreAutoConfiguration.PropertiesGemFireConnectionDetails(properties); } + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public GemFireVectorStore gemfireVectorStore(EmbeddingModel embeddingModel, GemFireVectorStoreProperties properties, GemFireConnectionDetails gemFireConnectionDetails, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { var builder = new GemFireVectorStore.GemFireVectorStoreConfig.Builder(); builder.setHost(gemFireConnectionDetails.getHost()) @@ -65,7 +74,7 @@ public GemFireVectorStore gemfireVectorStore(EmbeddingModel embeddingModel, GemF .setSslEnabled(properties.isSslEnabled()); return new GemFireVectorStore(builder.build(), embeddingModel, properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } private static class PropertiesGemFireConnectionDetails implements GemFireConnectionDetails { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java index bdae34efb40..5b7f4ddab72 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mongo/MongoDBAtlasVectorStoreAutoConfiguration.java @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.mongo; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.MongoDBAtlasVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; @@ -45,11 +48,18 @@ @EnableConfigurationProperties(MongoDBAtlasVectorStoreProperties.class) public class MongoDBAtlasVectorStoreAutoConfiguration { + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean MongoDBAtlasVectorStore vectorStore(MongoTemplate mongoTemplate, EmbeddingModel embeddingModel, MongoDBAtlasVectorStoreProperties properties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { var builder = MongoDBAtlasVectorStore.MongoDBVectorStoreConfig.builder(); @@ -66,7 +76,7 @@ MongoDBAtlasVectorStore vectorStore(MongoTemplate mongoTemplate, EmbeddingModel return new MongoDBAtlasVectorStore(mongoTemplate, embeddingModel, config, properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java index eeeeaf56ccf..fe19c255205 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.opensearch; import org.apache.hc.client5.http.auth.AuthScope; @@ -24,7 +25,10 @@ import org.opensearch.client.transport.aws.AwsSdk2Transport; import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; + +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.OpenSearchVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; @@ -58,17 +62,24 @@ PropertiesOpenSearchConnectionDetails openSearchConnectionDetails(OpenSearchVect return new PropertiesOpenSearchConnectionDetails(properties); } + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean OpenSearchVectorStore vectorStore(OpenSearchVectorStoreProperties properties, OpenSearchClient openSearchClient, EmbeddingModel embeddingModel, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { var indexName = Optional.ofNullable(properties.getIndexName()).orElse(OpenSearchVectorStore.DEFAULT_INDEX_NAME); var mappingJson = Optional.ofNullable(properties.getMappingJson()) .orElse(OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536); return new OpenSearchVectorStore(indexName, openSearchClient, embeddingModel, mappingJson, properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } @Configuration(proxyBeanMethods = false) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java index 3d3397dbf40..837e278c2b8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.oracle; import javax.sql.DataSource; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.OracleVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; @@ -35,22 +38,30 @@ * @author Loïc Lefèvre * @author Eddú Meléndez * @author Christian Tzolov + * @author Soby Chacko */ @AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) @ConditionalOnClass({ OracleVectorStore.class, DataSource.class, JdbcTemplate.class }) @EnableConfigurationProperties(OracleVectorStoreProperties.class) public class OracleVectorStoreAutoConfiguration { + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public OracleVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, OracleVectorStoreProperties properties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { return new OracleVectorStore(jdbcTemplate, embeddingModel, properties.getTableName(), properties.getIndexType(), properties.getDistanceType(), properties.getDimensions(), properties.getSearchAccuracy(), properties.isInitializeSchema(), properties.isRemoveExistingVectorStoreTable(), properties.isForcedNormalization(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java index 9756b88abc0..058b62841aa 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.pinecone; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.PineconeVectorStore; import org.springframework.ai.vectorstore.PineconeVectorStore.PineconeVectorStoreConfig; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; @@ -30,17 +33,25 @@ /** * @author Christian Tzolov + * @author Soby Chacko */ @AutoConfiguration @ConditionalOnClass({ PineconeVectorStore.class, EmbeddingModel.class }) @EnableConfigurationProperties(PineconeVectorStoreProperties.class) public class PineconeVectorStoreAutoConfiguration { + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public PineconeVectorStore vectorStore(EmbeddingModel embeddingModel, PineconeVectorStoreProperties properties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { var config = PineconeVectorStoreConfig.builder() .withApiKey(properties.getApiKey()) @@ -55,7 +66,7 @@ public PineconeVectorStore vectorStore(EmbeddingModel embeddingModel, PineconeVe return new PineconeVectorStore(config, embeddingModel, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null)); + customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } } diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index 24a7540e69a..4a5ea34459d 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore.azure; import java.util.ArrayList; @@ -25,7 +26,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -73,6 +77,7 @@ * @author Christian Tzolov * @author Josh Long * @author Thomas Vitale + * @author Soby Chacko */ public class AzureVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -116,6 +121,8 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + /** * List of metadata fields (as field name and type) that can be used in similarity * search query filter expressions. The {@link Document#getMetadata()} can contain @@ -175,7 +182,8 @@ public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embe */ public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel, boolean initializeSchema, List filterMetadataFields) { - this(searchIndexClient, embeddingModel, initializeSchema, filterMetadataFields, ObservationRegistry.NOOP, null); + this(searchIndexClient, embeddingModel, initializeSchema, filterMetadataFields, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } /** @@ -191,7 +199,7 @@ public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embe */ public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embeddingModel, boolean initializeSchema, List filterMetadataFields, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -204,6 +212,7 @@ public AzureVectorStore(SearchIndexClient searchIndexClient, EmbeddingModel embe this.embeddingModel = embeddingModel; this.filterMetadataFields = filterMetadataFields; this.filterExpressionConverter = new AzureAiSearchFilterExpressionConverter(filterMetadataFields); + this.batchingStrategy = batchingStrategy; } /** @@ -243,11 +252,12 @@ public void doAdd(List documents) { return; // nothing to do; } + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + final var searchDocuments = documents.stream().map(document -> { - final var embeddings = this.embeddingModel.embed(document); SearchDocument searchDocument = new SearchDocument(); searchDocument.put(ID_FIELD_NAME, document.getId()); - searchDocument.put(EMBEDDING_FIELD_NAME, embeddings); + searchDocument.put(EMBEDDING_FIELD_NAME, document.getEmbedding()); searchDocument.put(CONTENT_FIELD_NAME, document.getContent()); searchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString()); diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index ce58a246c3b..9b980817720 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -80,9 +80,9 @@ public void addAndSearchTest() { vectorStore.add(documents); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); - }, hasSize(1)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)), + hasSize(1)); List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); @@ -97,9 +97,8 @@ public void addAndSearchTest() { // Remove all documents from the store vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); - }, hasSize(0)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); }); } @@ -206,9 +205,7 @@ public void documentUpdateTest() { SearchRequest springSearchRequest = SearchRequest.query("Spring").withTopK(5); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(springSearchRequest); - }, hasSize(1)); + Awaitility.await().until(() -> vectorStore.similaritySearch(springSearchRequest), hasSize(1)); List results = vectorStore.similaritySearch(springSearchRequest); @@ -227,9 +224,9 @@ public void documentUpdateTest() { SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(); - }, equalTo("The World is Big and Salvation Lurks Around the Corner")); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(), + equalTo("The World is Big and Salvation Lurks Around the Corner")); results = vectorStore.similaritySearch(fooBarSearchRequest); @@ -242,9 +239,7 @@ public void documentUpdateTest() { // Remove all documents from the store vectorStore.delete(List.of(document.getId())); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(fooBarSearchRequest); - }, hasSize(0)); + Awaitility.await().until(() -> vectorStore.similaritySearch(fooBarSearchRequest), hasSize(0)); }); } @@ -258,10 +253,10 @@ public void searchThresholdTest() { vectorStore.add(documents); - Awaitility.await().until(() -> { - return vectorStore - .similaritySearch(SearchRequest.query("Depression").withTopK(50).withSimilarityThresholdAll()); - }, hasSize(3)); + Awaitility.await() + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("Depression").withTopK(50).withSimilarityThresholdAll()), + hasSize(3)); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); @@ -284,9 +279,8 @@ public void searchThresholdTest() { // Remove all documents from the store vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); - }, hasSize(0)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); }); } diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java index cced6b5f853..72f6c81294f 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreObservationIT.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.vectorstore.SearchRequest; @@ -179,7 +180,7 @@ public VectorStore vectorStore(SearchIndexClient searchIndexClient, EmbeddingMod var filterableMetaFields = List.of(MetadataField.text("country"), MetadataField.int64("year"), MetadataField.date("activationDate")); return new AzureVectorStore(searchIndexClient, embeddingModel, true, filterableMetaFields, - observationRegistry, null); + observationRegistry, null, new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 6c8cf70538e..35bb49420f3 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import com.datastax.oss.driver.api.core.cql.BoundStatement; @@ -35,7 +36,10 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -96,6 +100,7 @@ * @author Mick Semb Wever * @author Christian Tzolov * @author Thomas Vitale + * @author Soby Chacko * @see VectorStore * @see org.springframework.ai.vectorstore.CassandraVectorStoreConfig * @see EmbeddingModel @@ -137,12 +142,15 @@ public enum Similarity { private final Similarity similarity; + private final BatchingStrategy batchingStrategy; + public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embeddingModel) { - this(conf, embeddingModel, ObservationRegistry.NOOP, null); + this(conf, embeddingModel, ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy()); } public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embeddingModel, - ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) { + ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, + BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -166,21 +174,20 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embe this.filterExpressionConverter = new CassandraFilterExpressionConverter( cassandraMetadata.getColumns().values()); + this.batchingStrategy = batchingStrategy; } @Override public void doAdd(List documents) { var futures = new CompletableFuture[documents.size()]; + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + int i = 0; for (Document d : documents) { futures[i++] = CompletableFuture.runAsync(() -> { List primaryKeyValues = this.conf.documentIdTranslator.apply(d.getId()); - if (null == d.getEmbedding() || d.getEmbedding().length == 0) { - d.setEmbedding(this.embeddingModel.embed(d)); - } - BoundStatementBuilder builder = prepareAddStatement(d.getMetadata().keySet()).boundStatementBuilder(); for (int k = 0; k < primaryKeyValues.size(); ++k) { SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java index 3216b3a6715..1e3ac241e33 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreObservationIT.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn; @@ -173,7 +174,8 @@ public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddin .build(); conf.dropKeyspace(); - return new CassandraVectorStore(conf, embeddingModel, observationRegistry, null); + return new CassandraVectorStore(conf, embeddingModel, observationRegistry, null, + new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java index 390de98f748..05e6a01cff1 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import static org.springframework.http.HttpStatus.BAD_REQUEST; @@ -26,7 +27,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; @@ -75,6 +79,8 @@ public class GemFireVectorStore extends AbstractObservationVectorStore implement private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + /** * Configures and initializes a GemFireVectorStore instance based on the provided * configuration. @@ -84,7 +90,8 @@ public class GemFireVectorStore extends AbstractObservationVectorStore implement */ public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embeddingModel, boolean initializeSchema) { - this(config, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null); + this(config, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } /** @@ -99,7 +106,8 @@ public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embedd * observing operations */ public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embeddingModel, boolean initializeSchema, - ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) { + ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, + BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -118,6 +126,7 @@ public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embedd .build(config.sslEnabled ? "s" : "", config.host, config.port) .toString(); this.client = WebClient.create(base); + this.batchingStrategy = batchingStrategy; } // Create Index Parameters @@ -404,13 +413,11 @@ public void setDeleteData(boolean deleteData) { @Override public void doAdd(List documents) { - UploadRequest upload = new UploadRequest(documents.stream().map(document -> { - // Compute and assign an embedding to the document. - document.setEmbedding(this.embeddingModel.embed(document)); - float[] floatVector = document.getEmbedding(); - return new UploadRequest.Embedding(document.getId(), floatVector, DOCUMENT_FIELD, document.getContent(), - document.getMetadata()); - }).toList()); + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + UploadRequest upload = new UploadRequest(documents.stream() + .map(document -> new UploadRequest.Embedding(document.getId(), document.getEmbedding(), DOCUMENT_FIELD, + document.getContent(), document.getMetadata())) + .toList()); ObjectMapper objectMapper = new ObjectMapper(); String embeddingsJson = null; diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java index db2f370ad25..6fb44d38561 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java @@ -28,6 +28,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -199,7 +200,7 @@ public GemFireVectorStore vectorStore(EmbeddingModel embeddingModel, Observation .setHost("localhost") .setPort(HTTP_SERVICE_PORT) .setIndexName(TEST_INDEX_NAME) - .build(), embeddingModel, true, observationRegistry, null); + .build(), embeddingModel, true, observationRegistry, null, new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java index a3c745ae2b3..aa16e3f29be 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import static org.springframework.data.mongodb.core.query.Criteria.where; @@ -24,7 +25,10 @@ import java.util.Optional; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; @@ -81,6 +85,8 @@ public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore impl private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + public MongoDBAtlasVectorStore(MongoTemplate mongoTemplate, EmbeddingModel embeddingModel, boolean initializeSchema) { this(mongoTemplate, embeddingModel, MongoDBVectorStoreConfig.defaultConfig(), initializeSchema); @@ -88,12 +94,13 @@ public MongoDBAtlasVectorStore(MongoTemplate mongoTemplate, EmbeddingModel embed public MongoDBAtlasVectorStore(MongoTemplate mongoTemplate, EmbeddingModel embeddingModel, MongoDBVectorStoreConfig config, boolean initializeSchema) { - this(mongoTemplate, embeddingModel, config, initializeSchema, ObservationRegistry.NOOP, null); + this(mongoTemplate, embeddingModel, config, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } public MongoDBAtlasVectorStore(MongoTemplate mongoTemplate, EmbeddingModel embeddingModel, MongoDBVectorStoreConfig config, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -102,6 +109,7 @@ public MongoDBAtlasVectorStore(MongoTemplate mongoTemplate, EmbeddingModel embed this.config = config; this.initializeSchema = initializeSchema; + this.batchingStrategy = batchingStrategy; } @Override @@ -175,9 +183,8 @@ private Document mapMongoDocument(org.bson.Document mongoDocument, float[] query @Override public void doAdd(List documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); for (Document document : documents) { - float[] embedding = this.embeddingModel.embed(document); - document.setEmbedding(embedding); this.mongoTemplate.save(document, this.config.collectionName); } } diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java index 0d0ae6994e8..8adf17bf0cc 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbVectorStoreObservationIT.java @@ -28,6 +28,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -185,7 +186,7 @@ public VectorStore vectorStore(MongoTemplate mongoTemplate, EmbeddingModel embed MongoDBAtlasVectorStore.MongoDBVectorStoreConfig.builder() .withMetadataFieldsToFilter(List.of("country", "year")) .build(), - true, observationRegistry, null); + true, observationRegistry, null, new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java index 5bb55240742..1135241cf4e 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.opensearch.client.json.JsonData; @@ -30,7 +31,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.filter.Filter; @@ -91,6 +95,8 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + public OpenSearchVectorStore(OpenSearchClient openSearchClient, EmbeddingModel embeddingModel, boolean initializeSchema) { this(openSearchClient, embeddingModel, DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536, @@ -104,12 +110,13 @@ public OpenSearchVectorStore(OpenSearchClient openSearchClient, EmbeddingModel e public OpenSearchVectorStore(String index, OpenSearchClient openSearchClient, EmbeddingModel embeddingModel, String mappingJson, boolean initializeSchema) { - this(index, openSearchClient, embeddingModel, mappingJson, initializeSchema, ObservationRegistry.NOOP, null); + this(index, openSearchClient, embeddingModel, mappingJson, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } public OpenSearchVectorStore(String index, OpenSearchClient openSearchClient, EmbeddingModel embeddingModel, String mappingJson, boolean initializeSchema, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -124,6 +131,7 @@ public OpenSearchVectorStore(String index, OpenSearchClient openSearchClient, Em // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces this.similarityFunction = COSINE_SIMILARITY_FUNCTION; this.initializeSchema = initializeSchema; + this.batchingStrategy = batchingStrategy; } public OpenSearchVectorStore withSimilarityFunction(String similarityFunction) { @@ -133,12 +141,9 @@ public OpenSearchVectorStore withSimilarityFunction(String similarityFunction) { @Override public void doAdd(List documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); for (Document document : documents) { - if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().length == 0) { - logger.debug("Calling EmbeddingModel for document id = " + document.getId()); - document.setEmbedding(this.embeddingModel.embed(document)); - } bulkRequestBuilder .operations(op -> op.index(idx -> idx.index(this.index).id(document.getId()).document(document))); } diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java index 35077fc2a78..dbb417d24ee 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import org.apache.hc.core5.http.HttpHost; diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java index 8db248088e3..5e3faed1c66 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreObservationIT.java @@ -36,6 +36,7 @@ import org.opensearch.testcontainers.OpensearchContainer; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.openai.OpenAiEmbeddingModel; @@ -207,7 +208,7 @@ public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel, .builder(HttpHost.create(opensearchContainer.getHttpHostAddress())) .build()), embeddingModel, OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536, - true, observationRegistry, null); + true, observationRegistry, null, new TokenCountBatchingStrategy()); } catch (URISyntaxException e) { throw new RuntimeException(e); diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java index 260f4c49900..290f570f946 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import static org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT; @@ -33,7 +34,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; @@ -78,6 +82,7 @@ * * @author Loïc Lefèvre * @author Christian Tzolov + * @author Soby Chacko */ public class OracleVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -215,6 +220,8 @@ public enum OracleVectorStoreDistanceType { private final int searchAccuracy; + private final BatchingStrategy batchingStrategy; + public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { this(jdbcTemplate, embeddingModel, DEFAULT_TABLE_NAME, DEFAULT_INDEX_TYPE, DEFAULT_DISTANCE_TYPE, DEFAULT_DIMENSIONS, DEFAULT_SEARCH_ACCURACY, false, false, false); @@ -230,14 +237,15 @@ public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMode int searchAccuracy, boolean initializeSchema, boolean removeExistingVectorStoreTable, boolean forcedNormalization) { this(jdbcTemplate, embeddingModel, tableName, indexType, distanceType, dimensions, searchAccuracy, - initializeSchema, removeExistingVectorStoreTable, forcedNormalization, ObservationRegistry.NOOP, null); + initializeSchema, removeExistingVectorStoreTable, forcedNormalization, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy()); } public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, String tableName, OracleVectorStoreIndexType indexType, OracleVectorStoreDistanceType distanceType, int dimensions, int searchAccuracy, boolean initializeSchema, boolean removeExistingVectorStoreTable, boolean forcedNormalization, ObservationRegistry observationRegistry, - VectorStoreObservationConvention customObservationConvention) { + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -269,17 +277,19 @@ public OracleVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingMode this.initializeSchema = initializeSchema; this.removeExistingVectorStoreTable = removeExistingVectorStoreTable; this.forcedNormalization = forcedNormalization; + this.batchingStrategy = batchingStrategy; } @Override public void doAdd(final List documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); this.jdbcTemplate.batchUpdate(getIngestStatement(), new BatchPreparedStatementSetter() { @Override public void setValues(PreparedStatement ps, int i) throws SQLException { final Document document = documents.get(i); final String content = document.getContent(); final byte[] json = toJson(document.getMetadata()); - final VECTOR embeddingVector = toVECTOR(embeddingModel.embed(document)); + final VECTOR embeddingVector = toVECTOR(document.getEmbedding()); setParameterValue(ps, 1, Types.VARCHAR, document.getId()); setParameterValue(ps, 2, Types.VARCHAR, content); diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java index 851cf347a2b..534e026b72c 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreObservationIT.java @@ -27,6 +27,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -183,7 +184,8 @@ public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddi ObservationRegistry observationRegistry) { return new OracleVectorStore(jdbcTemplate, embeddingModel, OracleVectorStore.DEFAULT_TABLE_NAME, OracleVectorStore.OracleVectorStoreIndexType.IVF, OracleVectorStoreDistanceType.COSINE, 384, - OracleVectorStore.DEFAULT_SEARCH_ACCURACY, true, true, true, observationRegistry, null); + OracleVectorStore.DEFAULT_SEARCH_ACCURACY, true, true, true, observationRegistry, null, + new TokenCountBatchingStrategy()); } @Bean diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java index fdb96029967..59a7f7ff8e1 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,10 @@ import java.util.Optional; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; @@ -57,6 +60,7 @@ * * @author Christian Tzolov * @author Adam Bchouti + * @author Soby Chacko */ public class PineconeVectorStore extends AbstractObservationVectorStore { @@ -80,6 +84,8 @@ public class PineconeVectorStore extends AbstractObservationVectorStore { private final ObjectMapper objectMapper; + private final BatchingStrategy batchingStrategy; + /** * Configuration class for the PineconeVectorStore. */ @@ -260,7 +266,7 @@ public PineconeVectorStoreConfig build() { * @param embeddingModel The client for embedding operations. */ public PineconeVectorStore(PineconeVectorStoreConfig config, EmbeddingModel embeddingModel) { - this(config, embeddingModel, ObservationRegistry.NOOP, null); + this(config, embeddingModel, ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy()); } /** @@ -271,7 +277,8 @@ public PineconeVectorStore(PineconeVectorStoreConfig config, EmbeddingModel embe * @param customObservationConvention The custom observation convention. */ public PineconeVectorStore(PineconeVectorStoreConfig config, EmbeddingModel embeddingModel, - ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) { + ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, + BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); Assert.notNull(config, "PineconeVectorStoreConfig must not be null"); Assert.notNull(embeddingModel, "EmbeddingModel must not be null"); @@ -283,6 +290,7 @@ public PineconeVectorStore(PineconeVectorStoreConfig config, EmbeddingModel embe this.pineconeDistanceMetadataFieldName = config.distanceMetadataFieldName; this.pineconeConnection = new PineconeClient(config.clientConfig).connect(config.connectionConfig); this.objectMapper = new ObjectMapper(); + this.batchingStrategy = batchingStrategy; } /** @@ -291,17 +299,14 @@ public PineconeVectorStore(PineconeVectorStoreConfig config, EmbeddingModel embe * @param namespace The namespace to add the documents to */ public void add(List documents, String namespace) { - - List upsertVectors = documents.stream().map(document -> { - // Compute and assign an embedding to the document. - document.setEmbedding(this.embeddingModel.embed(document)); - - return Vector.newBuilder() + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + List upsertVectors = documents.stream() + .map(document -> Vector.newBuilder() .setId(document.getId()) .addAllValues(EmbeddingUtils.toList(document.getEmbedding())) .setMetadata(metadataToStruct(document)) - .build(); - }).toList(); + .build()) + .toList(); UpsertRequest upsertRequest = UpsertRequest.newBuilder() .addAllVectors(upsertVectors) diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java index 5ad55e50801..ee64e809e8e 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.transformers.TransformersEmbeddingModel; @@ -196,7 +197,8 @@ public PineconeVectorStoreConfig pineconeVectorStoreConfig() { @Bean public VectorStore vectorStore(PineconeVectorStoreConfig config, EmbeddingModel embeddingModel, ObservationRegistry observationRegistry) { - return new PineconeVectorStore(config, embeddingModel, observationRegistry, null); + return new PineconeVectorStore(config, embeddingModel, observationRegistry, null, + new TokenCountBatchingStrategy()); } @Bean