Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.CosmosDBVectorStore;
import org.springframework.ai.vectorstore.cosmosdb.CosmosDBVectorStore;
import org.springframework.ai.vectorstore.CosmosDBVectorStoreConfig;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -236,4 +236,4 @@ Add the following dependency in your Maven project:
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-azure-cosmos-db-store</artifactId>
</dependency>
----
----
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.CosmosDBVectorStore;
import org.springframework.ai.vectorstore.CosmosDBVectorStoreConfig;
import org.springframework.ai.vectorstore.cosmosdb.CosmosDBVectorStore;
import org.springframework.ai.vectorstore.cosmosdb.CosmosDBVectorStoreConfig;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.vectorstore.cosmosdb;

import java.util.Collection;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.vectorstore.cosmosdb;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -64,10 +63,13 @@
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.util.Assert;

/**
* Cosmos DB implementation.
Expand All @@ -83,35 +85,94 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen

private final CosmosAsyncClient cosmosClient;

private final EmbeddingModel embeddingModel;
private final String containerName;

private final CosmosDBVectorStoreConfig properties;
private final String databaseName;

private final String partitionKeyPath;

private final int vectorStoreThroughput;

private final long vectorDimensions;

private final List<String> metadataFieldsList;

private final BatchingStrategy batchingStrategy;

private CosmosAsyncContainer container;

/**
* Creates a new CosmosDBVectorStore with basic configuration.
* @param observationRegistry the observation registry
* @param customObservationConvention the custom observation convention
* @param cosmosClient the Cosmos DB client
* @param properties the configuration properties
* @param embeddingModel the embedding model
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public CosmosDBVectorStore(ObservationRegistry observationRegistry,
VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient,
CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel) {
this(observationRegistry, customObservationConvention, cosmosClient, properties, embeddingModel,
new TokenCountBatchingStrategy());
}

/**
* Creates a new CosmosDBVectorStore with full configuration.
* @param observationRegistry the observation registry
* @param customObservationConvention the custom observation convention
* @param cosmosClient the Cosmos DB client
* @param properties the configuration properties
* @param embeddingModel the embedding model
* @param batchingStrategy the batching strategy
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public CosmosDBVectorStore(ObservationRegistry observationRegistry,
VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient,
CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) {
super(observationRegistry, customObservationConvention);
this.cosmosClient = cosmosClient;
this.properties = properties;
this.batchingStrategy = batchingStrategy;
cosmosClient.createDatabaseIfNotExists(properties.getDatabaseName()).block();
this(builder().cosmosClient(cosmosClient)
.embeddingModel(embeddingModel)
.containerName(properties.getContainerName())
.databaseName(properties.getDatabaseName())
.partitionKeyPath(properties.getPartitionKeyPath())
.vectorStoreThroughput(properties.getVectorStoreThroughput())
.vectorDimensions(properties.getVectorDimensions())
.metadataFields(properties.getMetadataFieldsList())
.observationRegistry(observationRegistry)
.customObservationConvention(customObservationConvention)
.batchingStrategy(batchingStrategy));
}

initializeContainer(properties.getContainerName(), properties.getDatabaseName(),
properties.getVectorStoreThroughput(), properties.getVectorDimensions(),
properties.getPartitionKeyPath());
/**
* Protected constructor that accepts a builder instance. This is the preferred way to
* create new CosmosDBVectorStore instances.
* @param builder the configured builder instance
*/
protected CosmosDBVectorStore(CosmosDBBuilder builder) {
super(builder);

Assert.notNull(builder.cosmosClient, "CosmosClient must not be null");
Assert.hasText(builder.containerName, "Container name must not be empty");
Assert.hasText(builder.databaseName, "Database name must not be empty");
Assert.hasText(builder.partitionKeyPath, "Partition key path must not be empty");

this.cosmosClient = builder.cosmosClient;
this.containerName = builder.containerName;
this.databaseName = builder.databaseName;
this.partitionKeyPath = builder.partitionKeyPath;
this.vectorStoreThroughput = builder.vectorStoreThroughput;
this.vectorDimensions = builder.vectorDimensions;
this.metadataFieldsList = builder.metadataFieldsList;
this.batchingStrategy = builder.batchingStrategy;

cosmosClient.createDatabaseIfNotExists(databaseName).block();
initializeContainer(containerName, databaseName, vectorStoreThroughput, vectorDimensions, partitionKeyPath);
}

this.embeddingModel = embeddingModel;
public static CosmosDBBuilder builder() {
return new CosmosDBBuilder();
}

private void initializeContainer(String containerName, String databaseName, int vectorStoreThroughput,
Expand Down Expand Up @@ -309,7 +370,7 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
Filter.Expression filterExpression = request.getFilterExpression();
if (filterExpression != null) {
CosmosDBFilterExpressionConverter filterExpressionConverter = new CosmosDBFilterExpressionConverter(
this.properties.getMetadataFieldsList()); // Use the expression
this.metadataFieldsList); // Use the expression
// directly as
// it handles the
// "metadata"
Expand Down Expand Up @@ -360,4 +421,132 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
.withSimilarityMetric("cosine");
}

/**
* Builder class for creating {@link CosmosDBVectorStore} instances.
* <p>
* Provides a fluent API for configuring all aspects of the Cosmos DB vector store.
*
* @since 1.0.0
*/
public static class CosmosDBBuilder extends AbstractVectorStoreBuilder<CosmosDBBuilder> {

private CosmosAsyncClient cosmosClient;

private String containerName;

private String databaseName;

private String partitionKeyPath;

private int vectorStoreThroughput = 400;

private long vectorDimensions = 1536;

private List<String> metadataFieldsList = new ArrayList<>();

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

/**
* Sets the Cosmos DB client.
* @param cosmosClient the client to use
* @return the builder instance
* @throws IllegalArgumentException if cosmosClient is null
*/
public CosmosDBBuilder cosmosClient(CosmosAsyncClient cosmosClient) {
Assert.notNull(cosmosClient, "CosmosClient must not be null");
this.cosmosClient = cosmosClient;
return this;
}

/**
* Sets the container name.
* @param containerName the name of the container
* @return the builder instance
* @throws IllegalArgumentException if containerName is null or empty
*/
public CosmosDBBuilder containerName(String containerName) {
Assert.hasText(containerName, "Container name must not be empty");
this.containerName = containerName;
return this;
}

/**
* Sets the database name.
* @param databaseName the name of the database
* @return the builder instance
* @throws IllegalArgumentException if databaseName is null or empty
*/
public CosmosDBBuilder databaseName(String databaseName) {
Assert.hasText(databaseName, "Database name must not be empty");
this.databaseName = databaseName;
return this;
}

/**
* Sets the partition key path.
* @param partitionKeyPath the partition key path
* @return the builder instance
* @throws IllegalArgumentException if partitionKeyPath is null or empty
*/
public CosmosDBBuilder partitionKeyPath(String partitionKeyPath) {
Assert.hasText(partitionKeyPath, "Partition key path must not be empty");
this.partitionKeyPath = partitionKeyPath;
return this;
}

/**
* Sets the vector store throughput.
* @param vectorStoreThroughput the throughput value
* @return the builder instance
* @throws IllegalArgumentException if vectorStoreThroughput is not positive
*/
public CosmosDBBuilder vectorStoreThroughput(int vectorStoreThroughput) {
Assert.isTrue(vectorStoreThroughput > 0, "Vector store throughput must be positive");
this.vectorStoreThroughput = vectorStoreThroughput;
return this;
}

/**
* Sets the vector dimensions.
* @param vectorDimensions the number of dimensions
* @return the builder instance
* @throws IllegalArgumentException if vectorDimensions is not positive
*/
public CosmosDBBuilder vectorDimensions(long vectorDimensions) {
Assert.isTrue(vectorDimensions > 0, "Vector dimensions must be positive");
this.vectorDimensions = vectorDimensions;
return this;
}

/**
* Sets the metadata fields list.
* @param metadataFieldsList the list of metadata fields
* @return the builder instance
*/
public CosmosDBBuilder metadataFields(List<String> metadataFieldsList) {
this.metadataFieldsList = metadataFieldsList != null ? new ArrayList<>(metadataFieldsList)
: new ArrayList<>();
return this;
}

/**
* Sets the batching strategy.
* @param batchingStrategy the strategy to use
* @return the builder instance
* @throws IllegalArgumentException if batchingStrategy is null
*/
public CosmosDBBuilder batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return this;
}

@Override
public CosmosDBVectorStore build() {
validate();
return new CosmosDBVectorStore(this);
}

}

}
Loading
Loading