Skip to content

Commit 2d910e4

Browse files
committed
GH-1199: Prevent timeouts with configurable batching for PgVectorStore inserts
Resolves #1199.

- Implement configurable maxDocumentBatchSize to prevent insert timeouts when adding large numbers of documents
- Update PgVectorStore to process document inserts in controlled batches
- Add maxDocumentBatchSize property to PgVectorStoreProperties
- Update PgVectorStoreAutoConfiguration to use the new batching property
- Add tests to verify batching behavior and performance

This change addresses the issue of PgVectorStore inserts timing out due to large document volumes. By introducing configurable batching, users can now control the insert process to avoid timeouts while maintaining performance and reducing memory overhead for large-scale document additions.
1 parent acb31e7 commit 2d910e4

File tree

5 files changed

+142
-36
lines changed

5 files changed

+142
-36
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
7171
.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
7272
.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
7373
.withBatchingStrategy(batchingStrategy)
74+
.withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize())
7475
.build();
7576
}
7677

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
/**
2525
* @author Christian Tzolov
2626
* @author Muthukumaran Navaneethakrishnan
27+
* @author Soby Chacko
2728
*/
2829
@ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX)
2930
public class PgVectorStoreProperties extends CommonVectorStoreProperties {
@@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties {
4546

4647
private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;
4748

49+
private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE;
50+
4851
public int getDimensions() {
4952
return dimensions;
5053
}
@@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) {
101104
this.schemaValidation = schemaValidation;
102105
}
103106

107+
public int getMaxDocumentBatchSize() {
108+
return this.maxDocumentBatchSize;
109+
}
110+
111+
public void setMaxDocumentBatchSize(int maxDocumentBatchSize) {
112+
this.maxDocumentBatchSize = maxDocumentBatchSize;
113+
}
114+
104115
}

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java

Lines changed: 58 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
import java.sql.PreparedStatement;
1919
import java.sql.ResultSet;
2020
import java.sql.SQLException;
21+
import java.util.ArrayList;
2122
import java.util.List;
2223
import java.util.Map;
2324
import java.util.Optional;
2425
import java.util.UUID;
26+
import java.util.concurrent.atomic.AtomicInteger;
2527

2628
import org.postgresql.util.PGobject;
2729
import org.slf4j.Logger;
@@ -81,6 +83,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
8183

8284
public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter();
8385

86+
public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000;
87+
8488
private final String vectorTableName;
8589

8690
private final String vectorIndexName;
@@ -109,6 +113,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
109113

110114
private final BatchingStrategy batchingStrategy;
111115

116+
private final int maxDocumentBatchSize;
117+
112118
public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
113119
this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
114120
PgIndexType.NONE, false);
@@ -132,7 +138,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin
132138

133139
this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions,
134140
distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema);
135-
136141
}
137142

138143
private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
@@ -141,14 +146,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
141146

142147
this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
143148
distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
144-
ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
149+
ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE);
145150
}
146151

147152
private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
148153
JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
149154
boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
150155
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
151-
BatchingStrategy batchingStrategy) {
156+
BatchingStrategy batchingStrategy, int maxDocumentBatchSize) {
152157

153158
super(observationRegistry, customObservationConvention);
154159

@@ -172,6 +177,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
172177
this.initializeSchema = initializeSchema;
173178
this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
174179
this.batchingStrategy = batchingStrategy;
180+
this.maxDocumentBatchSize = maxDocumentBatchSize;
175181
}
176182

177183
public PgDistanceType getDistanceType() {
@@ -180,40 +186,50 @@ public PgDistanceType getDistanceType() {
180186

181187
@Override
182188
public void doAdd(List<Document> documents) {
189+
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
183190

184-
int size = documents.size();
191+
List<List<Document>> batchedDocuments = batchDocuments(documents);
192+
batchedDocuments.forEach(this::insertOrUpdateBatch);
193+
}
185194

186-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
195+
private List<List<Document>> batchDocuments(List<Document> documents) {
196+
List<List<Document>> batches = new ArrayList<>();
197+
for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) {
198+
batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size())));
199+
}
200+
return batches;
201+
}
187202

188-
this.jdbcTemplate.batchUpdate(
189-
"INSERT INTO " + getFullyQualifiedTableName()
190-
+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
191-
+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ",
192-
new BatchPreparedStatementSetter() {
193-
@Override
194-
public void setValues(PreparedStatement ps, int i) throws SQLException {
195-
196-
var document = documents.get(i);
197-
var content = document.getContent();
198-
var json = toJson(document.getMetadata());
199-
var embedding = document.getEmbedding();
200-
var pGvector = new PGvector(embedding);
201-
202-
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
203-
UUID.fromString(document.getId()));
204-
StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
205-
StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
206-
StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
207-
StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
208-
StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
209-
StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
210-
}
203+
private void insertOrUpdateBatch(List<Document> batch) {
204+
String sql = "INSERT INTO " + getFullyQualifiedTableName()
205+
+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
206+
+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ";
207+
208+
this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
209+
@Override
210+
public void setValues(PreparedStatement ps, int i) throws SQLException {
211+
212+
var document = batch.get(i);
213+
var content = document.getContent();
214+
var json = toJson(document.getMetadata());
215+
var embedding = document.getEmbedding();
216+
var pGvector = new PGvector(embedding);
217+
218+
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
219+
UUID.fromString(document.getId()));
220+
StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
221+
StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
222+
StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
223+
StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
224+
StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
225+
StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
226+
}
211227

212-
@Override
213-
public int getBatchSize() {
214-
return size;
215-
}
216-
});
228+
@Override
229+
public int getBatchSize() {
230+
return batch.size();
231+
}
232+
});
217233
}
218234

219235
private String toJson(Map<String, Object> map) {
@@ -509,6 +525,8 @@ public static class Builder {
509525

510526
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
511527

528+
private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE;
529+
512530
@Nullable
513531
private VectorStoreObservationConvention searchObservationConvention;
514532

@@ -576,11 +594,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
576594
return this;
577595
}
578596

597+
public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) {
598+
this.maxDocumentBatchSize = maxDocumentBatchSize;
599+
return this;
600+
}
601+
579602
public PgVectorStore build() {
580603
return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
581604
this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
582605
this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
583-
this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
606+
this.observationRegistry, this.searchObservationConvention, this.batchingStrategy,
607+
this.maxDocumentBatchSize);
584608
}
585609

586610
}

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,22 @@
2020
import org.mockito.Mock;
2121
import org.mockito.junit.jupiter.MockitoExtension;
2222

23+
import org.springframework.ai.document.Document;
2324
import org.springframework.ai.embedding.EmbeddingModel;
25+
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
2426
import org.springframework.jdbc.core.JdbcTemplate;
2527

2628
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.mockito.ArgumentMatchers.any;
2730
import static org.mockito.Mockito.never;
2831
import static org.mockito.Mockito.only;
32+
import static org.mockito.Mockito.times;
2933
import static org.mockito.Mockito.verify;
3034
import static org.mockito.Mockito.when;
3135

36+
import java.util.ArrayList;
37+
import java.util.List;
38+
3239
/**
3340
* @author Christian Tzolov
3441
*/
@@ -74,4 +81,20 @@ public void fallBackToDefaultDimensions() {
7481
verify(embeddingModel, only()).dimensions();
7582
}
7683

84+
@Test
85+
void foo() {
86+
PgVectorStore build = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(10)
87+
.build();
88+
List<Document> documents = new ArrayList<>();
89+
for (int i = 0; i < 97; i++) {
90+
documents.add(new Document("foo"));
91+
}
92+
93+
build.doAdd(documents);
94+
95+
verify(embeddingModel, only()).embed(any(), any(), any());
96+
97+
verify(jdbcTemplate, times(10)).batchUpdate(any(String.class), any(BatchPreparedStatementSetter.class));
98+
}
99+
77100
}

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,31 @@
1515
*/
1616
package org.springframework.ai.vectorstore;
1717

18+
import org.junit.jupiter.api.Test;
1819
import org.junit.jupiter.params.ParameterizedTest;
1920
import org.junit.jupiter.params.provider.CsvSource;
21+
import org.mockito.ArgumentCaptor;
2022

2123
import static org.assertj.core.api.Assertions.assertThat;
24+
import static org.mockito.ArgumentMatchers.any;
25+
import static org.mockito.ArgumentMatchers.anyString;
26+
import static org.mockito.ArgumentMatchers.eq;
27+
import static org.mockito.Mockito.mock;
28+
import static org.mockito.Mockito.only;
29+
import static org.mockito.Mockito.times;
30+
import static org.mockito.Mockito.verify;
31+
32+
import java.util.Collections;
33+
34+
import org.springframework.ai.document.Document;
35+
import org.springframework.ai.embedding.EmbeddingModel;
36+
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
37+
import org.springframework.jdbc.core.JdbcTemplate;
2238

2339
/**
2440
* @author Muthukumaran Navaneethakrishnan
41+
* @author Soby Chacko
2542
*/
26-
2743
public class PgVectorStoreTests {
2844

2945
@ParameterizedTest(name = "{0} - Verifies valid Table name")
@@ -53,8 +69,39 @@ public class PgVectorStoreTests {
5369
// 64
5470
// characters
5571
})
56-
public void isValidTable(String tableName, Boolean expected) {
72+
void isValidTable(String tableName, Boolean expected) {
5773
assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected);
5874
}
5975

76+
@Test
77+
void shouldAddDocumentsInBatchesAndEmbedOnce() {
78+
// Given
79+
var jdbcTemplate = mock(JdbcTemplate.class);
80+
var embeddingModel = mock(EmbeddingModel.class);
81+
var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000)
82+
.build();
83+
84+
// Testing with 9989 documents
85+
var documents = Collections.nCopies(9989, new Document("foo"));
86+
87+
// When
88+
pgVectorStore.doAdd(documents);
89+
90+
// Then
91+
verify(embeddingModel, only()).embed(eq(documents), any(), any());
92+
93+
var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class);
94+
verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture());
95+
96+
assertThat(batchUpdateCaptor.getAllValues()).hasSize(10)
97+
.allSatisfy(BatchPreparedStatementSetter::getBatchSize)
98+
.satisfies(batches -> {
99+
for (int i = 0; i < 9; i++) {
100+
assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 10", i)
101+
.isEqualTo(1000);
102+
}
103+
assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989);
104+
});
105+
}
106+
60107
}

0 commit comments

Comments (0)