From c39c13c33d6b0b304838ccf71df0c0ecab5d6af9 Mon Sep 17 00:00:00 2001
From: Soby Chacko
Date: Mon, 23 Sep 2024 21:37:04 -0400
Subject: [PATCH] GH-1199: Prevent timeouts with configurable batching for
 PgVectorStore inserts

Resolves https://github.com/spring-projects/spring-ai/issues/1199

- Implement configurable maxDocumentBatchSize to prevent insert timeouts when
  adding large numbers of documents
- Update PgVectorStore to process document inserts in controlled batches
- Add maxDocumentBatchSize property to PgVectorStoreProperties
- Update PgVectorStoreAutoConfiguration to use the new batching property
- Add tests to verify batching behavior and performance

This change addresses the issue of PgVectorStore inserts timing out due to
large document volumes. By introducing configurable batching, users can now
control the insert process to avoid timeouts while maintaining performance
and reducing memory overhead for large-scale document additions.
---
 .../PgVectorStoreAutoConfiguration.java       |  1 +
 .../pgvector/PgVectorStoreProperties.java     | 11 +++
 .../ai/vectorstore/PgVectorStore.java         | 92 ++++++++++++-------
 .../ai/vectorstore/PgVectorStoreTests.java    | 51 +++++++++-
 4 files changed, 119 insertions(+), 36 deletions(-)

diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java
index 1b5a62507ae..ec4d76e0748 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java
@@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
 			.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
 			.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
 			.withBatchingStrategy(batchingStrategy)
+			.withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize())
 			.build();
 	}
 
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java
index b4554174617..47a12c36d3e 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java
@@ -24,6 +24,7 @@
 /**
  * @author Christian Tzolov
  * @author Muthukumaran Navaneethakrishnan
+ * @author Soby Chacko
  */
 @ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX)
 public class PgVectorStoreProperties extends CommonVectorStoreProperties {
@@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties {
 
 	private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;
 
+	private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE;
+
 	public int getDimensions() {
 		return dimensions;
 	}
@@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) {
 		this.schemaValidation = schemaValidation;
 	}
 
+	public int getMaxDocumentBatchSize() {
+		return this.maxDocumentBatchSize;
+	}
+
+	public void setMaxDocumentBatchSize(int maxDocumentBatchSize) {
+		this.maxDocumentBatchSize = maxDocumentBatchSize;
+	}
+
 }
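With the auto-configuration and property changes above, the insert batch size becomes tunable from application configuration. A minimal sketch, assuming CONFIG_PREFIX resolves to the conventional spring.ai.vectorstore.pgvector and that Spring Boot's relaxed binding maps the maxDocumentBatchSize field as shown; the value is illustrative, not from this patch:

	# application.properties -- illustrative value, assumed property name
	spring.ai.vectorstore.pgvector.max-document-batch-size=5000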
diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
index 8ab0546fc5f..8239ef0890d 100644
--- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
+++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
@@ -18,10 +18,12 @@
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import org.postgresql.util.PGobject;
 import org.slf4j.Logger;
@@ -81,6 +83,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
 
 	public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter();
 
+	public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000;
+
 	private final String vectorTableName;
 
 	private final String vectorIndexName;
@@ -109,6 +113,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
 
 	private final BatchingStrategy batchingStrategy;
 
+	private final int maxDocumentBatchSize;
+
 	public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
 
 		this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
@@ -132,7 +138,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin
 
 		this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions,
 				distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema);
-
 	}
 
 	private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
@@ -141,14 +146,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
 
 		this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
 				distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
-				ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
+				ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE);
 	}
 
 	private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
 			JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
 			boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
 			ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
-			BatchingStrategy batchingStrategy) {
+			BatchingStrategy batchingStrategy, int maxDocumentBatchSize) {
 
 		super(observationRegistry, customObservationConvention);
 
@@ -172,6 +177,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
 		this.initializeSchema = initializeSchema;
 		this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
 		this.batchingStrategy = batchingStrategy;
+		this.maxDocumentBatchSize = maxDocumentBatchSize;
 	}
 
 	public PgDistanceType getDistanceType() {
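An aside before the batching logic itself (next hunk): applications that construct the store programmatically, rather than through the auto-configuration shown earlier, can set the same knob via the Builder API extended later in this patch. A minimal sketch; jdbcTemplate and embeddingModel are assumed to be supplied by the surrounding application, and 5_000 is an illustrative value:

	// Sketch: build a PgVectorStore with a custom insert batch size.
	// The default is MAX_DOCUMENT_BATCH_SIZE (10_000), declared above.
	PgVectorStore vectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel)
			.withMaxDocumentBatchSize(5_000)
			.build();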
@@ -180,40 +186,50 @@ public PgDistanceType getDistanceType() {
 
 	@Override
 	public void doAdd(List<Document> documents) {
+		this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
 
-		int size = documents.size();
+		List<List<Document>> batchedDocuments = batchDocuments(documents);
+		batchedDocuments.forEach(this::insertOrUpdateBatch);
+	}
 
-		this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
+	private List<List<Document>> batchDocuments(List<Document> documents) {
+		List<List<Document>> batches = new ArrayList<>();
+		for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) {
+			batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size())));
+		}
+		return batches;
+	}
 
-		this.jdbcTemplate.batchUpdate(
-				"INSERT INTO " + getFullyQualifiedTableName()
-						+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
-						+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ",
-				new BatchPreparedStatementSetter() {
-					@Override
-					public void setValues(PreparedStatement ps, int i) throws SQLException {
-
-						var document = documents.get(i);
-						var content = document.getContent();
-						var json = toJson(document.getMetadata());
-						var embedding = document.getEmbedding();
-						var pGvector = new PGvector(embedding);
-
-						StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
-								UUID.fromString(document.getId()));
-						StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
-						StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
-						StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-						StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
-						StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
-						StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-					}
+	private void insertOrUpdateBatch(List<Document> batch) {
+		String sql = "INSERT INTO " + getFullyQualifiedTableName()
+				+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
+				+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ";
+
+		this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
+			@Override
+			public void setValues(PreparedStatement ps, int i) throws SQLException {
+
+				var document = batch.get(i);
+				var content = document.getContent();
+				var json = toJson(document.getMetadata());
+				var embedding = document.getEmbedding();
+				var pGvector = new PGvector(embedding);
+
+				StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
+						UUID.fromString(document.getId()));
+				StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
+				StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
+				StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+				StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
+				StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
+				StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+			}
 
-					@Override
-					public int getBatchSize() {
-						return size;
-					}
-				});
+			@Override
+			public int getBatchSize() {
+				return batch.size();
+			}
+		});
 	}
 
 	private String toJson(Map<String, Object> map) {
@@ -509,6 +525,8 @@ public static class Builder {
 
 		private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
 
+		private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE;
+
 		@Nullable
 		private VectorStoreObservationConvention searchObservationConvention;
 
@@ -576,11 +594,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
 			return this;
 		}
 
+		public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) {
+			this.maxDocumentBatchSize = maxDocumentBatchSize;
+			return this;
+		}
+
 		public PgVectorStore build() {
 			return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
 					this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
 					this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
-					this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
+					this.observationRegistry, this.searchObservationConvention, this.batchingStrategy,
+					this.maxDocumentBatchSize);
 		}
 
 	}
diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java
index f5b69a922c9..488dbd3f73e 100644
--- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java
+++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java
@@ -15,15 +15,31 @@
  */
 package org.springframework.ai.vectorstore;
 
+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
+import org.mockito.ArgumentCaptor;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.only;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.util.Collections;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.jdbc.core.BatchPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcTemplate;
 
 /**
  * @author Muthukumaran Navaneethakrishnan
+ * @author Soby Chacko
  */
-
 public class PgVectorStoreTests {
 
 	@ParameterizedTest(name = "{0} - Verifies valid Table name")
@@ -53,8 +69,39 @@ public class PgVectorStoreTests {
 	// 64
 	// characters
 	})
-	public void isValidTable(String tableName, Boolean expected) {
+	void isValidTable(String tableName, Boolean expected) {
 		assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected);
 	}
 
+	@Test
+	void shouldAddDocumentsInBatchesAndEmbedOnce() {
+		// Given
+		var jdbcTemplate = mock(JdbcTemplate.class);
+		var embeddingModel = mock(EmbeddingModel.class);
+		var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000)
+			.build();
+
+		// Testing with 9989 documents
+		var documents = Collections.nCopies(9989, new Document("foo"));
+
+		// When
+		pgVectorStore.doAdd(documents);
+
+		// Then
+		verify(embeddingModel, only()).embed(eq(documents), any(), any());
+
+		var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class);
+		verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture());
+
+		assertThat(batchUpdateCaptor.getAllValues()).hasSize(10)
+			.allSatisfy(BatchPreparedStatementSetter::getBatchSize)
+			.satisfies(batches -> {
+				for (int i = 0; i < 9; i++) {
+					assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 1000", i)
+						.isEqualTo(1000);
+				}
+				assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989);
+			});
+	}
+
 }
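To make the arithmetic of the batchDocuments method above concrete, here is a standalone sketch of the same subList partitioning (a hypothetical demo class, not part of the patch). With 9,989 documents and a batch size of 1,000, the loop yields ten batches, the last holding the 989-document remainder; because subList returns views of the backing list, no documents are copied:

	import java.util.ArrayList;
	import java.util.Collections;
	import java.util.List;

	class SubListBatchingDemo {

		public static void main(String[] args) {
			List<String> documents = Collections.nCopies(9_989, "doc");
			int maxDocumentBatchSize = 1_000;

			List<List<String>> batches = new ArrayList<>();
			for (int i = 0; i < documents.size(); i += maxDocumentBatchSize) {
				// subList is a view over the backing list; no elements are copied
				batches.add(documents.subList(i, Math.min(i + maxDocumentBatchSize, documents.size())));
			}

			System.out.println(batches.size()); // 10
			System.out.println(batches.get(9).size()); // 989
		}

	}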
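Two closing observations on the change. First, embedding is still computed in a single embed(...) call for the entire document list (governed by the configured BatchingStrategy); only the JDBC work is partitioned, so each insertOrUpdateBatch issues one bounded batchUpdate and no single statement batch can hold the connection long enough to time out, which is the failure mode reported in GH-1199. Second, the test pins both sides of that contract: only() asserts that embed(...) is the sole interaction with the embedding model, times(10) fixes the number of batchUpdate calls, and the captured BatchPreparedStatementSetter instances expose each batch's size via getBatchSize(), nine full batches of 1,000 plus the 989-document remainder.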