From be22692dfa3cbde98cfeb9c5922b1e37bbd8ab10 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Tue, 17 Dec 2024 15:52:18 -0500 Subject: [PATCH] Coherence vector store builder refactoring --- .../CoherenceFilterExpressionConverter.java | 2 +- .../{ => coherence}/CoherenceVectorStore.java | 175 ++++++++++++++++-- ...herenceFilterExpressionConverterTests.java | 2 +- .../CoherenceVectorStoreIT.java | 16 +- 4 files changed, 172 insertions(+), 23 deletions(-) rename vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/{ => coherence}/CoherenceFilterExpressionConverter.java (98%) rename vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/{ => coherence}/CoherenceVectorStore.java (58%) rename vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/{ => coherence}/CoherenceFilterExpressionConverterTests.java (98%) rename vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/{ => coherence}/CoherenceVectorStoreIT.java (96%) diff --git a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceFilterExpressionConverter.java b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceFilterExpressionConverter.java similarity index 98% rename from vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceFilterExpressionConverter.java rename to vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceFilterExpressionConverter.java index 7946d7d4e8f..29f2f1311bc 100644 --- a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceFilterExpressionConverter.java +++ b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceFilterExpressionConverter.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.coherence; import java.util.List; diff --git a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java similarity index 58% rename from vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java rename to vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java index ea065a7d49f..208588bedf9 100644 --- a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java +++ b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.coherence; import java.util.ArrayList; import java.util.HashMap; @@ -39,8 +39,15 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; +import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** *

@@ -66,7 +73,7 @@ * @author Thomas Vitale * @since 1.0.0 */ -public class CoherenceVectorStore implements VectorStore, InitializingBean { +public class CoherenceVectorStore extends AbstractObservationVectorStore implements InitializingBean { public enum IndexType { @@ -112,11 +119,6 @@ public enum DistanceType { public static final CoherenceFilterExpressionConverter FILTER_EXPRESSION_CONVERTER = new CoherenceFilterExpressionConverter(); - /** - * The embedding model to use to create query embedding. - */ - private final EmbeddingModel embeddingModel; - private final int dimensions; private final Session session; @@ -126,45 +128,92 @@ public enum DistanceType { /** * Map name where vectors will be stored. */ - private String mapName = DEFAULT_MAP_NAME; + private String mapName; /** * Distance type to use for computing vector distances. */ - private DistanceType distanceType = DEFAULT_DISTANCE_TYPE; + private DistanceType distanceType; private boolean forcedNormalization; - private IndexType indexType = IndexType.NONE; + private IndexType indexType; + /** + * Creates a new CoherenceVectorStore with minimal configuration. + * @param embeddingModel the embedding model to use + * @param session the Coherence session + * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore(EmbeddingModel embeddingModel, Session session) { - this.embeddingModel = embeddingModel; - this.session = session; - this.dimensions = embeddingModel.dimensions(); + this(builder().embeddingModel(embeddingModel).session(session)); + } + + /** + * Protected constructor that accepts a builder instance. This is the preferred way to + * create new CoherenceVectorStore instances. + * @param builder the configured builder instance + */ + protected CoherenceVectorStore(CoherenceBuilder builder) { + super(builder); + + Assert.notNull(builder.session, "Session must not be null"); + + this.session = builder.session; + this.dimensions = builder.getEmbeddingModel().dimensions(); + this.mapName = builder.mapName; + this.distanceType = builder.distanceType; + this.forcedNormalization = builder.forcedNormalization; + this.indexType = builder.indexType; + } + + /** + * Creates a new builder for configuring and creating CoherenceVectorStore instances. + * @return a new builder instance + */ + public static CoherenceBuilder builder() { + return new CoherenceBuilder(); } + /** + * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore setMapName(String mapName) { this.mapName = mapName; return this; } + /** + * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore setDistanceType(DistanceType distanceType) { this.distanceType = distanceType; return this; } + /** + * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore setIndexType(IndexType indexType) { this.indexType = indexType; return this; } + /** + * @deprecated Since 1.0.0-M5, use {@link #builder()} instead + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public CoherenceVectorStore setForcedNormalization(boolean forcedNormalization) { this.forcedNormalization = forcedNormalization; return this; } @Override - public void add(final List documents) { + public void doAdd(final List documents) { Map chunks = new HashMap<>((int) Math.ceil(documents.size() / 0.75f)); for (Document doc : documents) { var id = toChunkId(doc.getId()); @@ -176,7 +225,7 @@ public void add(final List documents) { } @Override - public Optional delete(final List idList) { + public Optional doDelete(final List idList) { var chunkIds = idList.stream().map(this::toChunkId).toList(); Map results = this.documentChunks.invokeAll(chunkIds, entry -> { if (entry.isPresent()) { @@ -194,7 +243,7 @@ public Optional delete(final List idList) { } @Override - public List similaritySearch(SearchRequest request) { + public List doSimilaritySearch(SearchRequest request) { // From the provided query, generate a vector using the embedding model final Float32Vector vector = toFloat32Vector(this.embeddingModel.embed(request.getQuery())); @@ -265,4 +314,98 @@ String getMapName() { return this.mapName; } + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.NEO4J.value(), operationName) + .withCollectionName(this.mapName) + .withDimensions(this.embeddingModel.dimensions()); + } + + /** + * Builder class for creating {@link CoherenceVectorStore} instances. + *

+ * Provides a fluent API for configuring all aspects of the Coherence vector store, + * including map name, distance type, and indexing options. + * + * @since 1.0.0 + */ + public static class CoherenceBuilder extends AbstractVectorStoreBuilder { + + private Session session; + + private String mapName = DEFAULT_MAP_NAME; + + private DistanceType distanceType = DEFAULT_DISTANCE_TYPE; + + private boolean forcedNormalization = false; + + private IndexType indexType = IndexType.NONE; + + /** + * Sets the Coherence session. + * @param session the session to use + * @return the builder instance + * @throws IllegalArgumentException if session is null + */ + public CoherenceBuilder session(Session session) { + Assert.notNull(session, "Session must not be null"); + this.session = session; + return this; + } + + /** + * Sets the map name for vector storage. + * @param mapName the name of the map to use + * @return the builder instance + */ + public CoherenceBuilder mapName(String mapName) { + if (StringUtils.hasText(mapName)) { + this.mapName = mapName; + } + return this; + } + + /** + * Sets the distance type for vector similarity calculations. + * @param distanceType the distance type to use + * @return the builder instance + * @throws IllegalArgumentException if distanceType is null + */ + public CoherenceBuilder distanceType(DistanceType distanceType) { + Assert.notNull(distanceType, "DistanceType must not be null"); + this.distanceType = distanceType; + return this; + } + + /** + * Sets whether to force vector normalization. + * @param forcedNormalization true to force normalization, false otherwise + * @return the builder instance + */ + public CoherenceBuilder forcedNormalization(boolean forcedNormalization) { + this.forcedNormalization = forcedNormalization; + return this; + } + + /** + * Sets the index type for vector storage. + * @param indexType the index type to use + * @return the builder instance + * @throws IllegalArgumentException if indexType is null + */ + public CoherenceBuilder indexType(IndexType indexType) { + Assert.notNull(indexType, "IndexType must not be null"); + this.indexType = indexType; + return this; + } + + @Override + public CoherenceVectorStore build() { + validate(); + return new CoherenceVectorStore(this); + } + + } + } diff --git a/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceFilterExpressionConverterTests.java b/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/coherence/CoherenceFilterExpressionConverterTests.java similarity index 98% rename from vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceFilterExpressionConverterTests.java rename to vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/coherence/CoherenceFilterExpressionConverterTests.java index 47ac89259d0..b1487f3e2ec 100644 --- a/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/coherence/CoherenceFilterExpressionConverterTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.coherence; import com.tangosol.util.Filters; import com.tangosol.util.ValueExtractor; diff --git a/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java b/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStoreIT.java similarity index 96% rename from vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java rename to vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStoreIT.java index 776d449b518..6e8734b8c7d 100644 --- a/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java +++ b/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStoreIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vectorstore; +package org.springframework.ai.vectorstore.coherence; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -50,6 +50,8 @@ import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -309,10 +311,14 @@ public static class TestClient { @Bean public VectorStore vectorStore(EmbeddingModel embeddingModel, Session session) { - return new CoherenceVectorStore(embeddingModel, session).setDistanceType(this.distanceType) - .setIndexType(this.indexType) - .setForcedNormalization(this.distanceType == CoherenceVectorStore.DistanceType.COSINE - || this.distanceType == CoherenceVectorStore.DistanceType.IP); + return CoherenceVectorStore.builder() + .embeddingModel(embeddingModel) + .session(session) + .distanceType(this.distanceType) + .indexType(this.indexType) + .forcedNormalization(this.distanceType == CoherenceVectorStore.DistanceType.COSINE + || this.distanceType == CoherenceVectorStore.DistanceType.IP) + .build(); } @Bean