From f602b873fb68018182345e9550c7935b8c43b0ce Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Mon, 9 Dec 2024 19:48:51 -0500 Subject: [PATCH 1/2] Redis vector store builder refactoring --- .../RedisVectorStoreAutoConfiguration.java | 21 +- .../RedisFilterExpressionConverter.java | 4 +- .../vectorstore/RedisVectorStore.java | 249 +++++++++++++++--- .../RedisFilterExpressionConverterTests.java | 38 ++- .../vectorstore/RedisVectorStoreIT.java | 22 +- .../RedisVectorStoreObservationIT.java | 26 +- 6 files changed, 265 insertions(+), 95 deletions(-) rename vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/{ => redis}/vectorstore/RedisFilterExpressionConverter.java (97%) rename vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/{ => redis}/vectorstore/RedisVectorStore.java (68%) rename vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/{ => redis}/vectorstore/RedisFilterExpressionConverterTests.java (71%) rename vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/{ => redis}/vectorstore/RedisVectorStoreIT.java (93%) rename vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/{ => redis}/vectorstore/RedisVectorStoreObservationIT.java (90%) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java index 6a252bbcced..3476a8bb233 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java @@ -22,8 +22,8 @@ import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.TokenCountBatchingStrategy; -import org.springframework.ai.vectorstore.RedisVectorStore; -import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; +import org.springframework.ai.redis.vectorstore.RedisVectorStore; +import org.springframework.ai.redis.vectorstore.RedisVectorStore.RedisVectorStoreConfig; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -61,15 +61,16 @@ public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorSt ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - var config = RedisVectorStoreConfig.builder() - .withIndexName(properties.getIndex()) - .withPrefix(properties.getPrefix()) + return RedisVectorStore + .builder(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort())) + .embeddingModel(embeddingModel) + .initializeSchema(properties.isInitializeSchema()) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .customObservationConvention(customObservationConvention.getIfAvailable(() -> null)) + .batchingStrategy(batchingStrategy) + .indexName(properties.getIndex()) + .prefix(properties.getPrefix()) .build(); - - return new RedisVectorStore(config, embeddingModel, - new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), - properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), - customObservationConvention.getIfAvailable(() -> null), batchingStrategy); } } diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisFilterExpressionConverter.java similarity index 97% rename from vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java rename to vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisFilterExpressionConverter.java index 4536eda085f..379ebd731b8 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverter.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisFilterExpressionConverter.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.redis.vectorstore; import java.text.MessageFormat; import java.util.List; @@ -22,7 +22,7 @@ import java.util.function.Function; import java.util.stream.Collectors; -import org.springframework.ai.vectorstore.RedisVectorStore.MetadataField; +import org.springframework.ai.redis.vectorstore.RedisVectorStore.MetadataField; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisVectorStore.java similarity index 68% rename from vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java rename to vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisVectorStore.java index 8415f8dfa0f..0b7fe3a3034 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisVectorStore.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.redis.vectorstore; import java.text.MessageFormat; import java.util.ArrayList; @@ -54,6 +54,9 @@ import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; +import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; @@ -61,6 +64,7 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * The RedisVectorStore is for managing and querying vector data in a Redis database. It @@ -119,40 +123,63 @@ public class RedisVectorStore extends AbstractObservationVectorStore implements private static final String DEFAULT_DISTANCE_METRIC = "COSINE"; + private final JedisPooled jedis; + private final boolean initializeSchema; - private final JedisPooled jedis; + private final String indexName; + + private final String prefix; + + private final String contentFieldName; - private final EmbeddingModel embeddingModel; + private final String embeddingFieldName; - private final RedisVectorStoreConfig config; + private final Algorithm vectorAlgorithm; + + private final List metadataFields; private final BatchingStrategy batchingStrategy; - private FilterExpressionConverter filterExpressionConverter; + private final FilterExpressionConverter filterExpressionConverter; + @Deprecated(since = "1.0.0-M5", forRemoval = true) public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, boolean initializeSchema) { - this(config, embeddingModel, jedis, initializeSchema, ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy()); } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - super(observationRegistry, customObservationConvention); - - Assert.notNull(config, "Config must not be null"); - Assert.notNull(embeddingModel, "Embedding model must not be null"); - this.initializeSchema = initializeSchema; + this(builder(jedis).embeddingModel(embeddingModel) + .indexName(config.indexName) + .prefix(config.prefix) + .contentFieldName(config.contentFieldName) + .embeddingFieldName(config.embeddingFieldName) + .vectorAlgorithm(config.vectorAlgorithm) + .metadataFields(config.metadataFields) + .initializeSchema(initializeSchema) + .observationRegistry(observationRegistry) + .customObservationConvention(customObservationConvention) + .batchingStrategy(batchingStrategy)); + } - this.jedis = jedis; - this.embeddingModel = embeddingModel; - this.config = config; - this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields); - this.batchingStrategy = batchingStrategy; + private RedisVectorStore(RedisBuilder builder) { + super(builder); + this.jedis = builder.jedis; + this.indexName = builder.indexName; + this.prefix = builder.prefix; + this.contentFieldName = builder.contentFieldName; + this.embeddingFieldName = builder.embeddingFieldName; + this.vectorAlgorithm = builder.vectorAlgorithm; + this.metadataFields = builder.metadataFields; + this.initializeSchema = builder.initializeSchema; + this.batchingStrategy = builder.batchingStrategy; + this.filterExpressionConverter = new RedisFilterExpressionConverter(this.metadataFields); } public JedisPooled getJedis() { @@ -168,8 +195,8 @@ public void doAdd(List documents) { for (Document document : documents) { var fields = new HashMap(); - fields.put(this.config.embeddingFieldName, embeddings.get(documents.indexOf(document))); - fields.put(this.config.contentFieldName, document.getContent()); + fields.put(this.embeddingFieldName, embeddings.get(documents.indexOf(document))); + fields.put(this.contentFieldName, document.getContent()); fields.putAll(document.getMetadata()); pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields); } @@ -186,7 +213,7 @@ public void doAdd(List documents) { } private String key(String id) { - return this.config.prefix + id; + return this.prefix + id; } @Override @@ -216,13 +243,13 @@ public List doSimilaritySearch(SearchRequest request) { String filter = nativeExpressionFilter(request); - String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName, + String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.embeddingFieldName, EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME); List returnFields = new ArrayList<>(); - this.config.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add); - returnFields.add(this.config.embeddingFieldName); - returnFields.add(this.config.contentFieldName); + this.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add); + returnFields.add(this.embeddingFieldName); + returnFields.add(this.contentFieldName); returnFields.add(DISTANCE_FIELD_NAME); var embedding = this.embeddingModel.embed(request.getQuery()); Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding)) @@ -231,7 +258,7 @@ public List doSimilaritySearch(SearchRequest request) { .limit(0, request.getTopK()) .dialect(2); - SearchResult result = this.jedis.ftSearch(this.config.indexName, query); + SearchResult result = this.jedis.ftSearch(this.indexName, query); return result.getDocuments() .stream() .filter(d -> similarityScore(d) >= request.getSimilarityThreshold()) @@ -240,9 +267,9 @@ public List doSimilaritySearch(SearchRequest request) { } private Document toDocument(redis.clients.jedis.search.Document doc) { - var id = doc.getId().substring(this.config.prefix.length()); - var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) : ""; - Map metadata = this.config.metadataFields.stream() + var id = doc.getId().substring(this.prefix.length()); + var content = doc.hasProperty(this.contentFieldName) ? doc.getString(this.contentFieldName) : ""; + Map metadata = this.metadataFields.stream() .map(MetadataField::name) .filter(doc::hasProperty) .collect(Collectors.toMap(Function.identity(), doc::getString)); @@ -272,12 +299,12 @@ public void afterPropertiesSet() { } // If index already exists don't do anything - if (this.jedis.ftList().contains(this.config.indexName)) { + if (this.jedis.ftList().contains(this.indexName)) { return; } - String response = this.jedis.ftCreate(this.config.indexName, - FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.config.prefix), schemaFields()); + String response = this.jedis.ftCreate(this.indexName, + FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.prefix), schemaFields()); if (!RESPONSE_OK.test(response)) { String message = MessageFormat.format("Could not create index: {0}", response); throw new RuntimeException(message); @@ -290,16 +317,16 @@ private Iterable schemaFields() { vectorAttrs.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC); vectorAttrs.put("TYPE", VECTOR_TYPE_FLOAT32); List fields = new ArrayList<>(); - fields.add(TextField.of(jsonPath(this.config.contentFieldName)).as(this.config.contentFieldName).weight(1.0)); + fields.add(TextField.of(jsonPath(this.contentFieldName)).as(this.contentFieldName).weight(1.0)); fields.add(VectorField.builder() - .fieldName(jsonPath(this.config.embeddingFieldName)) + .fieldName(jsonPath(this.embeddingFieldName)) .algorithm(vectorAlgorithm()) .attributes(vectorAttrs) - .as(this.config.embeddingFieldName) + .as(this.embeddingFieldName) .build()); - if (!CollectionUtils.isEmpty(this.config.metadataFields)) { - for (MetadataField field : this.config.metadataFields) { + if (!CollectionUtils.isEmpty(this.metadataFields)) { + for (MetadataField field : this.metadataFields) { fields.add(schemaField(field)); } } @@ -318,7 +345,7 @@ private SchemaField schemaField(MetadataField field) { } private VectorAlgorithm vectorAlgorithm() { - if (this.config.vectorAlgorithm == Algorithm.HSNW) { + if (this.vectorAlgorithm == Algorithm.HSNW) { return VectorAlgorithm.HNSW; } return VectorAlgorithm.FLAT; @@ -332,9 +359,9 @@ private String jsonPath(String field) { public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { return VectorStoreObservationContext.builder(VectorStoreProvider.REDIS.value(), operationName) - .withCollectionName(this.config.indexName) + .withCollectionName(this.indexName) .withDimensions(this.embeddingModel.dimensions()) - .withFieldName(this.config.embeddingFieldName) + .withFieldName(this.embeddingFieldName) .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); } @@ -361,9 +388,150 @@ public static MetadataField tag(String name) { } + public static RedisBuilder builder(JedisPooled jedis) { + return new RedisBuilder(jedis); + } + + public static class RedisBuilder extends AbstractVectorStoreBuilder { + + private final JedisPooled jedis; + + private String indexName = DEFAULT_INDEX_NAME; + + private String prefix = DEFAULT_PREFIX; + + private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME; + + private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME; + + private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; + + private List metadataFields = new ArrayList<>(); + + private boolean initializeSchema = false; + + private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + + public RedisBuilder(JedisPooled jedis) { + Assert.notNull(jedis, "JedisPooled must not be null"); + this.jedis = jedis; + } + + /** + * Sets the Redis index name. + * @param indexName the index name to use + * @return the builder instance + */ + public RedisBuilder indexName(String indexName) { + if (StringUtils.hasText(indexName)) { + this.indexName = indexName; + } + return this; + } + + /** + * Sets the Redis key prefix (default: "embedding:"). + * @param prefix the prefix to use + * @return the builder instance + */ + public RedisBuilder prefix(String prefix) { + if (StringUtils.hasText(prefix)) { + this.prefix = prefix; + } + return this; + } + + /** + * Sets the Redis content field name. + * @param fieldName the content field name to use + * @return the builder instance + */ + public RedisBuilder contentFieldName(String fieldName) { + if (StringUtils.hasText(fieldName)) { + this.contentFieldName = fieldName; + } + return this; + } + + /** + * Sets the Redis embedding field name. + * @param fieldName the embedding field name to use + * @return the builder instance + */ + public RedisBuilder embeddingFieldName(String fieldName) { + if (StringUtils.hasText(fieldName)) { + this.embeddingFieldName = fieldName; + } + return this; + } + + /** + * Sets the Redis vector algorithm. + * @param algorithm the vector algorithm to use + * @return the builder instance + */ + public RedisBuilder vectorAlgorithm(Algorithm algorithm) { + if (algorithm != null) { + this.vectorAlgorithm = algorithm; + } + return this; + } + + /** + * Sets the metadata fields. + * @param fields the metadata fields to include + * @return the builder instance + */ + public RedisBuilder metadataFields(MetadataField... fields) { + return metadataFields(Arrays.asList(fields)); + } + + /** + * Sets the metadata fields. + * @param fields the list of metadata fields to include + * @return the builder instance + */ + public RedisBuilder metadataFields(List fields) { + if (fields != null && !fields.isEmpty()) { + this.metadataFields = new ArrayList<>(fields); + } + return this; + } + + /** + * Sets whether to initialize the schema. + * @param initializeSchema true to initialize schema, false otherwise + * @return the builder instance + */ + public RedisBuilder initializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + return this; + } + + /** + * Sets the batching strategy. + * @param batchingStrategy the strategy to use + * @return the builder instance + * @throws IllegalArgumentException if batchingStrategy is null + */ + public RedisBuilder batchingStrategy(BatchingStrategy batchingStrategy) { + Assert.notNull(batchingStrategy, "BatchingStrategy must not be null"); + this.batchingStrategy = batchingStrategy; + return this; + } + + @Override + public RedisVectorStore build() { + validate(); + return new RedisVectorStore(this); + } + + } + /** * Configuration for the Redis vector store. */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public static final class RedisVectorStoreConfig { private final String indexName; @@ -395,19 +563,20 @@ private RedisVectorStoreConfig(Builder builder) { * Start building a new configuration. * @return The entry point for creating a new configuration. */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public static Builder builder() { - return new Builder(); } /** * {@return the default config} */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public static RedisVectorStoreConfig defaultConfig() { - return builder().build(); } + @Deprecated(since = "1.0.0-M5", forRemoval = true) public static final class Builder { private String indexName = DEFAULT_INDEX_NAME; diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisFilterExpressionConverterTests.java similarity index 71% rename from vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java rename to vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisFilterExpressionConverterTests.java index 07e8ef0e928..e1ae8c43e91 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisFilterExpressionConverterTests.java @@ -14,14 +14,14 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.redis.vectorstore; import java.util.Arrays; import java.util.List; import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.RedisVectorStore.MetadataField; +import org.springframework.ai.redis.vectorstore.RedisVectorStore.MetadataField; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Key; @@ -49,7 +49,7 @@ private static RedisFilterExpressionConverter converter(MetadataField... fields) @Test void testEQ() { // country == "BG" - String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country")) + String vectorExpr = converter(RedisVectorStore.MetadataField.tag("country")) .convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("@country:{BG}"); } @@ -57,8 +57,8 @@ void testEQ() { @Test void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("genre"), - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("year")) + String vectorExpr = converter(RedisVectorStore.MetadataField.tag("genre"), + RedisVectorStore.MetadataField.numeric("year")) .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr).isEqualTo("@genre:{drama} @year:[2020 inf]"); @@ -67,18 +67,16 @@ void tesEqAndGte() { @Test void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("genre")) - .convertExpression( - new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); + String vectorExpr = converter(RedisVectorStore.MetadataField.tag("genre")).convertExpression( + new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("@genre:{comedy | documentary | drama}"); } @Test void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("year"), - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country"), - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("city")) + String vectorExpr = converter(RedisVectorStore.MetadataField.numeric("year"), + RedisVectorStore.MetadataField.tag("country"), RedisVectorStore.MetadataField.tag("city")) .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Group(new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia")))))); @@ -88,9 +86,8 @@ void testNe() { @Test void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("year"), - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country"), - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("city")) + String vectorExpr = converter(RedisVectorStore.MetadataField.numeric("year"), + RedisVectorStore.MetadataField.tag("country"), RedisVectorStore.MetadataField.tag("city")) .convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), @@ -101,9 +98,8 @@ void testGroup() { @Test void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("year"), - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country"), - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("isOpen")) + String vectorExpr = converter(RedisVectorStore.MetadataField.numeric("year"), + RedisVectorStore.MetadataField.tag("country"), RedisVectorStore.MetadataField.tag("isOpen")) .convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), @@ -115,8 +111,7 @@ void tesBoolean() { @Test void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter( - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("temperature")) + String vectorExpr = converter(RedisVectorStore.MetadataField.numeric("temperature")) .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -125,12 +120,11 @@ void testDecimal() { @Test void testComplexIdentifiers() { - String vectorExpr = converter( - org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country 1 2 3")) + String vectorExpr = converter(RedisVectorStore.MetadataField.tag("country 1 2 3")) .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("@\"country 1 2 3\":{BG}"); - vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country 1 2 3")) + vectorExpr = converter(RedisVectorStore.MetadataField.tag("country 1 2 3")) .convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("@'country 1 2 3':{BG}"); } diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisVectorStoreIT.java similarity index 93% rename from vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java rename to vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisVectorStoreIT.java index 6c389b672d6..983bc6d8b27 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisVectorStoreIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.redis.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -34,8 +34,10 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; -import org.springframework.ai.vectorstore.RedisVectorStore.MetadataField; -import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; +import org.springframework.ai.redis.vectorstore.RedisVectorStore.MetadataField; +import org.springframework.ai.redis.vectorstore.RedisVectorStore.RedisVectorStoreConfig; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -255,13 +257,13 @@ public static class TestApplication { @Bean public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, JedisConnectionFactory jedisConnectionFactory) { - return new RedisVectorStore( - RedisVectorStoreConfig.builder() - .withMetadataFields(MetadataField.tag("meta1"), MetadataField.tag("meta2"), - MetadataField.tag("country"), MetadataField.numeric("year")) - .build(), - embeddingModel, - new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), true); + return RedisVectorStore + .builder(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort())) + .embeddingModel(embeddingModel) + .metadataFields(MetadataField.tag("meta1"), MetadataField.tag("meta2"), MetadataField.tag("country"), + MetadataField.numeric("year")) + .initializeSchema(true) + .build(); } @Bean diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisVectorStoreObservationIT.java similarity index 90% rename from vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java rename to vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisVectorStoreObservationIT.java index 2d2ed538cb2..0e6e99e298b 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/redis/vectorstore/RedisVectorStoreObservationIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.redis.vectorstore; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -38,8 +38,10 @@ import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.transformers.TransformersEmbeddingModel; -import org.springframework.ai.vectorstore.RedisVectorStore.MetadataField; -import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; +import org.springframework.ai.redis.vectorstore.RedisVectorStore.MetadataField; +import org.springframework.ai.redis.vectorstore.RedisVectorStore.RedisVectorStoreConfig; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames; @@ -175,14 +177,16 @@ public TestObservationRegistry observationRegistry() { @Bean public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, JedisConnectionFactory jedisConnectionFactory, ObservationRegistry observationRegistry) { - return new RedisVectorStore( - RedisVectorStoreConfig.builder() - .withMetadataFields(MetadataField.tag("meta1"), MetadataField.tag("meta2"), - MetadataField.tag("country"), MetadataField.numeric("year")) - .build(), - embeddingModel, - new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), true, - observationRegistry, null, new TokenCountBatchingStrategy()); + return RedisVectorStore + .builder(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort())) + .embeddingModel(embeddingModel) + .observationRegistry(observationRegistry) + .customObservationConvention(null) + .initializeSchema(true) + .batchingStrategy(new TokenCountBatchingStrategy()) + .metadataFields(MetadataField.tag("meta1"), MetadataField.tag("meta2"), MetadataField.tag("country"), + MetadataField.numeric("year")) + .build(); } @Bean From 8a17f683b7fef04b47d156c4f5c199e218b38e48 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Tue, 10 Dec 2024 20:56:07 -0500 Subject: [PATCH 2/2] Redis vector store builder refactoring --- .../RedisVectorStoreAutoConfiguration.java | 4 ++-- .../ai/redis/vectorstore/RedisVectorStore.java | 17 +++++++++++------ .../redis/vectorstore/RedisVectorStoreIT.java | 4 ++-- .../RedisVectorStoreObservationIT.java | 4 ++-- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java index 3476a8bb233..6b02e62a5c1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/redis/RedisVectorStoreAutoConfiguration.java @@ -61,8 +61,8 @@ public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorSt ObjectProvider customObservationConvention, BatchingStrategy batchingStrategy) { - return RedisVectorStore - .builder(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort())) + return RedisVectorStore.builder() + .jedis(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort())) .embeddingModel(embeddingModel) .initializeSchema(properties.isInitializeSchema()) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisVectorStore.java index 0b7fe3a3034..182df2b8892 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/redis/vectorstore/RedisVectorStore.java @@ -155,7 +155,8 @@ public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingM boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) { - this(builder(jedis).embeddingModel(embeddingModel) + this(builder().jedis(jedis) + .embeddingModel(embeddingModel) .indexName(config.indexName) .prefix(config.prefix) .contentFieldName(config.contentFieldName) @@ -168,8 +169,11 @@ public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingM .batchingStrategy(batchingStrategy)); } - private RedisVectorStore(RedisBuilder builder) { + protected RedisVectorStore(RedisBuilder builder) { super(builder); + + Assert.notNull(builder.jedis, "JedisPooled must not be null"); + this.jedis = builder.jedis; this.indexName = builder.indexName; this.prefix = builder.prefix; @@ -388,13 +392,13 @@ public static MetadataField tag(String name) { } - public static RedisBuilder builder(JedisPooled jedis) { - return new RedisBuilder(jedis); + public static RedisBuilder builder() { + return new RedisBuilder(); } public static class RedisBuilder extends AbstractVectorStoreBuilder { - private final JedisPooled jedis; + private JedisPooled jedis; private String indexName = DEFAULT_INDEX_NAME; @@ -412,9 +416,10 @@ public static class RedisBuilder extends AbstractVectorStoreBuilder