Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.vectorstore.coherence;

import java.util.List;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.vectorstore.coherence;

import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -39,8 +39,15 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* <p>
Expand All @@ -66,7 +73,7 @@
* @author Thomas Vitale
* @since 1.0.0
*/
public class CoherenceVectorStore implements VectorStore, InitializingBean {
public class CoherenceVectorStore extends AbstractObservationVectorStore implements InitializingBean {

public enum IndexType {

Expand Down Expand Up @@ -112,11 +119,6 @@ public enum DistanceType {

public static final CoherenceFilterExpressionConverter FILTER_EXPRESSION_CONVERTER = new CoherenceFilterExpressionConverter();

/**
* The embedding model to use to create query embedding.
*/
private final EmbeddingModel embeddingModel;

private final int dimensions;

private final Session session;
Expand All @@ -126,45 +128,92 @@ public enum DistanceType {
/**
* Map name where vectors will be stored.
*/
private String mapName = DEFAULT_MAP_NAME;
private String mapName;

/**
* Distance type to use for computing vector distances.
*/
private DistanceType distanceType = DEFAULT_DISTANCE_TYPE;
private DistanceType distanceType;

private boolean forcedNormalization;

private IndexType indexType = IndexType.NONE;
private IndexType indexType;

/**
* Creates a new CoherenceVectorStore with minimal configuration.
* @param embeddingModel the embedding model to use
* @param session the Coherence session
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public CoherenceVectorStore(EmbeddingModel embeddingModel, Session session) {
this.embeddingModel = embeddingModel;
this.session = session;
this.dimensions = embeddingModel.dimensions();
this(builder().embeddingModel(embeddingModel).session(session));
}

/**
* Protected constructor that accepts a builder instance. This is the preferred way to
* create new CoherenceVectorStore instances.
* @param builder the configured builder instance
*/
protected CoherenceVectorStore(CoherenceBuilder builder) {
super(builder);

Assert.notNull(builder.session, "Session must not be null");

this.session = builder.session;
this.dimensions = builder.getEmbeddingModel().dimensions();
this.mapName = builder.mapName;
this.distanceType = builder.distanceType;
this.forcedNormalization = builder.forcedNormalization;
this.indexType = builder.indexType;
}

/**
* Creates a new builder for configuring and creating CoherenceVectorStore instances.
* @return a new builder instance
*/
public static CoherenceBuilder builder() {
return new CoherenceBuilder();
}

/**
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public CoherenceVectorStore setMapName(String mapName) {
this.mapName = mapName;
return this;
}

/**
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public CoherenceVectorStore setDistanceType(DistanceType distanceType) {
this.distanceType = distanceType;
return this;
}

/**
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public CoherenceVectorStore setIndexType(IndexType indexType) {
this.indexType = indexType;
return this;
}

/**
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public CoherenceVectorStore setForcedNormalization(boolean forcedNormalization) {
this.forcedNormalization = forcedNormalization;
return this;
}

@Override
public void add(final List<Document> documents) {
public void doAdd(final List<Document> documents) {
Map<DocumentChunk.Id, DocumentChunk> chunks = new HashMap<>((int) Math.ceil(documents.size() / 0.75f));
for (Document doc : documents) {
var id = toChunkId(doc.getId());
Expand All @@ -176,7 +225,7 @@ public void add(final List<Document> documents) {
}

@Override
public Optional<Boolean> delete(final List<String> idList) {
public Optional<Boolean> doDelete(final List<String> idList) {
var chunkIds = idList.stream().map(this::toChunkId).toList();
Map<DocumentChunk.Id, Boolean> results = this.documentChunks.invokeAll(chunkIds, entry -> {
if (entry.isPresent()) {
Expand All @@ -194,7 +243,7 @@ public Optional<Boolean> delete(final List<String> idList) {
}

@Override
public List<Document> similaritySearch(SearchRequest request) {
public List<Document> doSimilaritySearch(SearchRequest request) {
// From the provided query, generate a vector using the embedding model
final Float32Vector vector = toFloat32Vector(this.embeddingModel.embed(request.getQuery()));

Expand Down Expand Up @@ -265,4 +314,98 @@ String getMapName() {
return this.mapName;
}

@Override
public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {

return VectorStoreObservationContext.builder(VectorStoreProvider.NEO4J.value(), operationName)
.withCollectionName(this.mapName)
.withDimensions(this.embeddingModel.dimensions());
}

/**
* Builder class for creating {@link CoherenceVectorStore} instances.
* <p>
* Provides a fluent API for configuring all aspects of the Coherence vector store,
* including map name, distance type, and indexing options.
*
* @since 1.0.0
*/
public static class CoherenceBuilder extends AbstractVectorStoreBuilder<CoherenceBuilder> {

private Session session;

private String mapName = DEFAULT_MAP_NAME;

private DistanceType distanceType = DEFAULT_DISTANCE_TYPE;

private boolean forcedNormalization = false;

private IndexType indexType = IndexType.NONE;

/**
* Sets the Coherence session.
* @param session the session to use
* @return the builder instance
* @throws IllegalArgumentException if session is null
*/
public CoherenceBuilder session(Session session) {
Assert.notNull(session, "Session must not be null");
this.session = session;
return this;
}

/**
* Sets the map name for vector storage.
* @param mapName the name of the map to use
* @return the builder instance
*/
public CoherenceBuilder mapName(String mapName) {
if (StringUtils.hasText(mapName)) {
this.mapName = mapName;
}
return this;
}

/**
* Sets the distance type for vector similarity calculations.
* @param distanceType the distance type to use
* @return the builder instance
* @throws IllegalArgumentException if distanceType is null
*/
public CoherenceBuilder distanceType(DistanceType distanceType) {
Assert.notNull(distanceType, "DistanceType must not be null");
this.distanceType = distanceType;
return this;
}

/**
* Sets whether to force vector normalization.
* @param forcedNormalization true to force normalization, false otherwise
* @return the builder instance
*/
public CoherenceBuilder forcedNormalization(boolean forcedNormalization) {
this.forcedNormalization = forcedNormalization;
return this;
}

/**
* Sets the index type for vector storage.
* @param indexType the index type to use
* @return the builder instance
* @throws IllegalArgumentException if indexType is null
*/
public CoherenceBuilder indexType(IndexType indexType) {
Assert.notNull(indexType, "IndexType must not be null");
this.indexType = indexType;
return this;
}

@Override
public CoherenceVectorStore build() {
validate();
return new CoherenceVectorStore(this);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.vectorstore.coherence;

import com.tangosol.util.Filters;
import com.tangosol.util.ValueExtractor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.vectorstore.coherence;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -50,6 +50,8 @@
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringBootConfiguration;
Expand Down Expand Up @@ -309,10 +311,14 @@ public static class TestClient {

@Bean
public VectorStore vectorStore(EmbeddingModel embeddingModel, Session session) {
return new CoherenceVectorStore(embeddingModel, session).setDistanceType(this.distanceType)
.setIndexType(this.indexType)
.setForcedNormalization(this.distanceType == CoherenceVectorStore.DistanceType.COSINE
|| this.distanceType == CoherenceVectorStore.DistanceType.IP);
return CoherenceVectorStore.builder()
.embeddingModel(embeddingModel)
.session(session)
.distanceType(this.distanceType)
.indexType(this.indexType)
.forcedNormalization(this.distanceType == CoherenceVectorStore.DistanceType.COSINE
|| this.distanceType == CoherenceVectorStore.DistanceType.IP)
.build();
}

@Bean
Expand Down
Loading