diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java index 1b5a62507ae..ec4d76e0748 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java @@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null)) .withBatchingStrategy(batchingStrategy) + .withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java index b4554174617..47a12c36d3e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java @@ -24,6 +24,7 @@ /** * @author Christian Tzolov * @author Muthukumaran Navaneethakrishnan + * @author Soby Chacko */ @ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX) public class PgVectorStoreProperties extends CommonVectorStoreProperties { @@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties { 
private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION; + private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE; + public int getDimensions() { return dimensions; } @@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) { this.schemaValidation = schemaValidation; } + public int getMaxDocumentBatchSize() { + return this.maxDocumentBatchSize; + } + + public void setMaxDocumentBatchSize(int maxDocumentBatchSize) { + this.maxDocumentBatchSize = maxDocumentBatchSize; + } + } diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 8ab0546fc5f..8239ef0890d 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -18,10 +18,11 @@ import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; import org.postgresql.util.PGobject; import org.slf4j.Logger; @@ -81,6 +83,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter(); + public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000; + private final String vectorTableName; private final String vectorIndexName; @@ -109,6 +113,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini private final BatchingStrategy batchingStrategy; + private final int maxDocumentBatchSize; + public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false, PgIndexType.NONE, false); @@ -132,7 +138,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema); - } private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, @@ -141,14 +146,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema, - ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy()); + ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE); } private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, - BatchingStrategy batchingStrategy) { + BatchingStrategy batchingStrategy, int maxDocumentBatchSize) { super(observationRegistry, customObservationConvention); @@ -172,6 +177,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT this.initializeSchema = initializeSchema; this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate); this.batchingStrategy = batchingStrategy; + this.maxDocumentBatchSize = maxDocumentBatchSize; } public PgDistanceType getDistanceType() { @@ -180,40 +186,50 @@ public PgDistanceType getDistanceType() { @Override public void 
doAdd(List<Document> documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); - int size = documents.size(); + List<List<Document>> batchedDocuments = batchDocuments(documents); + batchedDocuments.forEach(this::insertOrUpdateBatch); + } - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + private List<List<Document>> batchDocuments(List<Document> documents) { + List<List<Document>> batches = new ArrayList<>(); + for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) { + batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size()))); + } + return batches; + } - this.jdbcTemplate.batchUpdate( - "INSERT INTO " + getFullyQualifiedTableName() - + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " - + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", - new BatchPreparedStatementSetter() { - @Override - public void setValues(PreparedStatement ps, int i) throws SQLException { - - var document = documents.get(i); - var content = document.getContent(); - var json = toJson(document.getMetadata()); - var embedding = document.getEmbedding(); - var pGvector = new PGvector(embedding); - - StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, - UUID.fromString(document.getId())); - StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); - StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); - StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); - StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content); - StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json); - StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector); - } + private void insertOrUpdateBatch(List<Document> batch) { + String sql = "INSERT INTO " + getFullyQualifiedTableName() + + " (id, 
content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " + + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? "; + + this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + + var document = batch.get(i); + var content = document.getContent(); + var json = toJson(document.getMetadata()); + var embedding = document.getEmbedding(); + var pGvector = new PGvector(embedding); + + StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, + UUID.fromString(document.getId())); + StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); + StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); + StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); + StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content); + StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json); + StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector); + } - @Override - public int getBatchSize() { - return size; - } - }); + @Override + public int getBatchSize() { + return batch.size(); + } + }); } private String toJson(Map<String, Object> map) { @@ -509,6 +525,8 @@ public static class Builder { private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE; + @Nullable private VectorStoreObservationConvention searchObservationConvention; @@ -576,11 +594,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) { return this; } + public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) { + this.maxDocumentBatchSize = maxDocumentBatchSize; + return this; + } + public PgVectorStore build() { return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled, 
this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType, this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema, - this.observationRegistry, this.searchObservationConvention, this.batchingStrategy); + this.observationRegistry, this.searchObservationConvention, this.batchingStrategy, + this.maxDocumentBatchSize); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java index f5b69a922c9..488dbd3f73e 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java @@ -15,15 +15,31 @@ */ package org.springframework.ai.vectorstore; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.ArgumentCaptor; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.only; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.Collections; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; /** * @author Muthukumaran Navaneethakrishnan + * @author Soby Chacko */ - public class PgVectorStoreTests { @ParameterizedTest(name = "{0} - Verifies valid Table name") @@ -53,8 +69,39 @@ public class PgVectorStoreTests { // 64 // characters }) - 
public void isValidTable(String tableName, Boolean expected) { + void isValidTable(String tableName, Boolean expected) { assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected); } + @Test + void shouldAddDocumentsInBatchesAndEmbedOnce() { + // Given + var jdbcTemplate = mock(JdbcTemplate.class); + var embeddingModel = mock(EmbeddingModel.class); + var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000) + .build(); + + // Testing with 9989 documents + var documents = Collections.nCopies(9989, new Document("foo")); + + // When + pgVectorStore.doAdd(documents); + + // Then + verify(embeddingModel, only()).embed(eq(documents), any(), any()); + + var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class); + verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture()); + + assertThat(batchUpdateCaptor.getAllValues()).hasSize(10) + .allSatisfy(BatchPreparedStatementSetter::getBatchSize) + .satisfies(batches -> { + for (int i = 0; i < 9; i++) { + assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 1000", i) + .isEqualTo(1000); + } + assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989); + }); + } + }