Skip to content

Commit 4dbe734

Browse files
jitokim authored and ilayaperumalg committed
Refactor ID handling for different IdType formats
- Add handling for the UUID, TEXT, INTEGER, SERIAL, and BIGSERIAL ID formats in the `convertIdToPgType` method.
- Implement type-conversion logic based on the configured IdType value (UUID, TEXT, INTEGER, SERIAL, BIGSERIAL).
- Add tests in `PgVectorStoreIT` to validate correct conversion for UUID and non-UUID IdType formats:
  - `testToPgTypeWithUuidIdType`: validates UUID handling.
  - `testToPgTypeWithNonUuidIdType`: validates handling for non-UUID IdTypes.

Signed-off-by: jitokim <[email protected]>
1 parent f40945b commit 4dbe734

File tree

2 files changed

+101
-6
lines changed

2 files changed

+101
-6
lines changed

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java

Lines changed: 38 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,10 +35,8 @@
3535

3636
import org.springframework.ai.document.Document;
3737
import org.springframework.ai.document.DocumentMetadata;
38-
import org.springframework.ai.embedding.BatchingStrategy;
3938
import org.springframework.ai.embedding.EmbeddingModel;
4039
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
41-
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
4240
import org.springframework.ai.observation.conventions.VectorStoreProvider;
4341
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
4442
import org.springframework.ai.util.JacksonUtils;
@@ -153,6 +151,7 @@
153151
* @author Thomas Vitale
154152
* @author Soby Chacko
155153
* @author Sebastien Deleuze
154+
* @author Jihoon Kim
156155
* @since 1.0.0
157156
*/
158157
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {
@@ -163,6 +162,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
163162

164163
public static final String DEFAULT_TABLE_NAME = "vector_store";
165164

165+
public static final PgIdType DEFAULT_ID_TYPE = PgIdType.UUID;
166+
166167
public static final String DEFAULT_VECTOR_INDEX_NAME = "spring_ai_vector_index";
167168

168169
public static final String DEFAULT_SCHEMA_NAME = "public";
@@ -188,6 +189,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
188189

189190
private final String schemaName;
190191

192+
private final PgIdType idType;
193+
191194
private final boolean schemaValidation;
192195

193196
private final boolean initializeSchema;
@@ -225,6 +228,7 @@ protected PgVectorStore(PgVectorStoreBuilder builder) {
225228
: this.vectorTableName + "_index";
226229

227230
this.schemaName = builder.schemaName;
231+
this.idType = builder.idType;
228232
this.schemaValidation = builder.vectorTableValidationsEnabled;
229233

230234
this.jdbcTemplate = builder.jdbcTemplate;
@@ -273,13 +277,13 @@ private void insertOrUpdateBatch(List<Document> batch, List<Document> documents,
273277
public void setValues(PreparedStatement ps, int i) throws SQLException {
274278

275279
var document = batch.get(i);
280+
var id = convertIdToPgType(document.getId());
276281
var content = document.getText();
277282
var json = toJson(document.getMetadata());
278283
var embedding = embeddings.get(documents.indexOf(document));
279284
var pGvector = new PGvector(embedding);
280285

281-
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
282-
UUID.fromString(document.getId()));
286+
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, id);
283287
StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
284288
StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
285289
StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
@@ -304,6 +308,19 @@ private String toJson(Map<String, Object> map) {
304308
}
305309
}
306310

311+
private Object convertIdToPgType(String id) {
312+
if (this.initializeSchema) {
313+
return UUID.fromString(id);
314+
}
315+
316+
return switch (getIdType()) {
317+
case UUID -> UUID.fromString(id);
318+
case TEXT -> id;
319+
case INTEGER, SERIAL -> Integer.valueOf(id);
320+
case BIGSERIAL -> Long.valueOf(id);
321+
};
322+
}
323+
307324
@Override
308325
public Optional<Boolean> doDelete(List<String> idList) {
309326
int updateCount = 0;
@@ -429,6 +446,10 @@ private String getFullyQualifiedTableName() {
429446
return this.schemaName + "." + this.vectorTableName;
430447
}
431448

449+
private PgIdType getIdType() {
450+
return this.idType;
451+
}
452+
432453
private String getVectorTableName() {
433454
return this.vectorTableName;
434455
}
@@ -513,6 +534,12 @@ public enum PgIndexType {
513534

514535
}
515536

537+
public enum PgIdType {
538+
539+
UUID, TEXT, INTEGER, SERIAL, BIGSERIAL
540+
541+
}
542+
516543
/**
517544
* Defaults to CosineDistance. But if vectors are normalized to length 1 (like OpenAI
518545
* embeddings), use inner product (NegativeInnerProduct) for best performance.
@@ -608,6 +635,8 @@ public static final class PgVectorStoreBuilder extends AbstractVectorStoreBuilde
608635

609636
private String vectorTableName = PgVectorStore.DEFAULT_TABLE_NAME;
610637

638+
private PgIdType idType = PgVectorStore.DEFAULT_ID_TYPE;
639+
611640
private boolean vectorTableValidationsEnabled = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;
612641

613642
private int dimensions = PgVectorStore.INVALID_EMBEDDING_DIMENSION;
@@ -638,6 +667,11 @@ public PgVectorStoreBuilder vectorTableName(String vectorTableName) {
638667
return this;
639668
}
640669

670+
public PgVectorStoreBuilder idType(PgIdType idType) {
671+
this.idType = idType;
672+
return this;
673+
}
674+
641675
public PgVectorStoreBuilder vectorTableValidationsEnabled(boolean vectorTableValidationsEnabled) {
642676
this.vectorTableValidationsEnabled = vectorTableValidationsEnabled;
643677
return this;

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java

Lines changed: 63 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
import java.io.IOException;
2020
import java.nio.charset.StandardCharsets;
2121
import java.util.Collections;
22+
import java.util.HashMap;
2223
import java.util.Iterator;
2324
import java.util.List;
2425
import java.util.Map;
@@ -42,14 +43,16 @@
4243

4344
import org.springframework.ai.document.Document;
4445
import org.springframework.ai.document.DocumentMetadata;
46+
import org.springframework.ai.document.id.RandomIdGenerator;
4547
import org.springframework.ai.embedding.EmbeddingModel;
4648
import org.springframework.ai.openai.OpenAiEmbeddingModel;
4749
import org.springframework.ai.openai.api.OpenAiApi;
50+
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIdType;
51+
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType;
4852
import org.springframework.ai.vectorstore.SearchRequest;
4953
import org.springframework.ai.vectorstore.VectorStore;
5054
import org.springframework.ai.vectorstore.filter.Filter;
5155
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser.FilterExpressionParseException;
52-
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType;
5356
import org.springframework.beans.factory.annotation.Value;
5457
import org.springframework.boot.SpringBootConfiguration;
5558
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
@@ -70,6 +73,7 @@
7073
* @author Muthukumaran Navaneethakrishnan
7174
* @author Christian Tzolov
7275
* @author Thomas Vitale
76+
* @author Jihoon Kim
7377
*/
7478
@Testcontainers
7579
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@@ -106,6 +110,27 @@ public static String getText(String uri) {
106110
}
107111
}
108112

113+
private static void initSchema(ApplicationContext context) {
114+
PgVectorStore vectorStore = context.getBean(PgVectorStore.class);
115+
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
116+
// Enable the PGVector, JSONB and UUID support.
117+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
118+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
119+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
120+
121+
jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", PgVectorStore.DEFAULT_SCHEMA_NAME));
122+
123+
jdbcTemplate.execute(String.format("""
124+
CREATE TABLE IF NOT EXISTS %s.%s (
125+
id text PRIMARY KEY,
126+
content text,
127+
metadata json,
128+
embedding vector(%d)
129+
)
130+
""", PgVectorStore.DEFAULT_SCHEMA_NAME, PgVectorStore.DEFAULT_TABLE_NAME,
131+
vectorStore.embeddingDimensions()));
132+
}
133+
109134
private static void dropTable(ApplicationContext context) {
110135
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
111136
jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
@@ -169,6 +194,35 @@ public void addAndSearch(String distanceType) {
169194
});
170195
}
171196

197+
@Test
198+
public void testToPgTypeWithUuidIdType() {
199+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
200+
.run(context -> {
201+
202+
VectorStore vectorStore = context.getBean(VectorStore.class);
203+
204+
vectorStore.add(List.of(new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>())));
205+
206+
dropTable(context);
207+
});
208+
}
209+
210+
@Test
211+
public void testToPgTypeWithNonUuidIdType() {
212+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
213+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
214+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
215+
.run(context -> {
216+
217+
VectorStore vectorStore = context.getBean(VectorStore.class);
218+
initSchema(context);
219+
220+
vectorStore.add(List.of(new Document("NOT_UUID", "TEXT", new HashMap<>())));
221+
222+
dropTable(context);
223+
});
224+
}
225+
172226
@ParameterizedTest(name = "Filter expression {0} should return {1} records ")
173227
@MethodSource("provideFilters")
174228
public void searchWithInFilter(String expression, Integer expectedRecords) {
@@ -498,12 +552,19 @@ public static class TestApplication {
498552
@Value("${test.spring.ai.vectorstore.pgvector.distanceType}")
499553
PgVectorStore.PgDistanceType distanceType;
500554

555+
@Value("${test.spring.ai.vectorstore.pgvector.initializeSchema:true}")
556+
boolean initializeSchema;
557+
558+
@Value("${test.spring.ai.vectorstore.pgvector.idType:UUID}")
559+
PgIdType idType;
560+
501561
@Bean
502562
public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
503563
return PgVectorStore.builder(jdbcTemplate, embeddingModel)
504564
.dimensions(PgVectorStore.INVALID_EMBEDDING_DIMENSION)
565+
.idType(idType)
505566
.distanceType(this.distanceType)
506-
.initializeSchema(true)
567+
.initializeSchema(initializeSchema)
507568
.indexType(PgIndexType.HNSW)
508569
.removeExistingVectorStoreTable(true)
509570
.build();

0 commit comments

Comments
 (0)