Skip to content

Commit d4a6e1e

Browse files
sobychackomarkpollack
authored andcommitted
CosmosDB vector store builder refactoring
1 parent 138e11a commit d4a6e1e

File tree

7 files changed

+292
-29
lines changed

7 files changed

+292
-29
lines changed

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/azure-cosmos-db.adoc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ import io.micrometer.observation.ObservationRegistry;
157157
import org.springframework.ai.document.Document;
158158
import org.springframework.ai.embedding.EmbeddingModel;
159159
import org.springframework.ai.transformers.TransformersEmbeddingModel;
160-
import org.springframework.ai.vectorstore.CosmosDBVectorStore;
160+
import org.springframework.ai.vectorstore.cosmosdb.CosmosDBVectorStore;
161161
import org.springframework.ai.vectorstore.CosmosDBVectorStoreConfig;
162162
import org.springframework.ai.vectorstore.VectorStore;
163163
import org.springframework.beans.factory.annotation.Autowired;
@@ -236,4 +236,4 @@ Add the following dependency in your Maven project:
236236
<groupId>org.springframework.ai</groupId>
237237
<artifactId>spring-ai-azure-cosmos-db-store</artifactId>
238238
</dependency>
239-
----
239+
----

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import org.springframework.ai.embedding.BatchingStrategy;
2424
import org.springframework.ai.embedding.EmbeddingModel;
2525
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
26-
import org.springframework.ai.vectorstore.CosmosDBVectorStore;
27-
import org.springframework.ai.vectorstore.CosmosDBVectorStoreConfig;
26+
import org.springframework.ai.vectorstore.cosmosdb.CosmosDBVectorStore;
27+
import org.springframework.ai.vectorstore.cosmosdb.CosmosDBVectorStoreConfig;
2828
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
2929
import org.springframework.beans.factory.ObjectProvider;
3030
import org.springframework.boot.autoconfigure.AutoConfiguration;
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.vectorstore;
17+
package org.springframework.ai.vectorstore.cosmosdb;
1818

1919
import java.util.Collection;
2020
import java.util.Map;
Lines changed: 203 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.ai.vectorstore;
17+
package org.springframework.ai.vectorstore.cosmosdb;
1818

1919
import java.util.ArrayList;
2020
import java.util.Collections;
@@ -63,10 +63,13 @@
6363
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
6464
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
6565
import org.springframework.ai.observation.conventions.VectorStoreProvider;
66+
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
67+
import org.springframework.ai.vectorstore.SearchRequest;
6668
import org.springframework.ai.vectorstore.filter.Filter;
6769
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
6870
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
6971
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
72+
import org.springframework.util.Assert;
7073

7174
/**
7275
* Cosmos DB implementation.
@@ -82,35 +85,94 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen
8285

8386
private final CosmosAsyncClient cosmosClient;
8487

85-
private final EmbeddingModel embeddingModel;
88+
private final String containerName;
8689

87-
private final CosmosDBVectorStoreConfig properties;
90+
private final String databaseName;
91+
92+
private final String partitionKeyPath;
93+
94+
private final int vectorStoreThroughput;
95+
96+
private final long vectorDimensions;
97+
98+
private final List<String> metadataFieldsList;
8899

89100
private final BatchingStrategy batchingStrategy;
90101

91102
private CosmosAsyncContainer container;
92103

104+
/**
105+
* Creates a new CosmosDBVectorStore with basic configuration.
106+
* @param observationRegistry the observation registry
107+
* @param customObservationConvention the custom observation convention
108+
* @param cosmosClient the Cosmos DB client
109+
* @param properties the configuration properties
110+
* @param embeddingModel the embedding model
111+
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
112+
*/
113+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
93114
public CosmosDBVectorStore(ObservationRegistry observationRegistry,
94115
VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient,
95116
CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel) {
96117
this(observationRegistry, customObservationConvention, cosmosClient, properties, embeddingModel,
97118
new TokenCountBatchingStrategy());
98119
}
99120

121+
/**
122+
* Creates a new CosmosDBVectorStore with full configuration.
123+
* @param observationRegistry the observation registry
124+
* @param customObservationConvention the custom observation convention
125+
* @param cosmosClient the Cosmos DB client
126+
* @param properties the configuration properties
127+
* @param embeddingModel the embedding model
128+
* @param batchingStrategy the batching strategy
129+
* @deprecated Since 1.0.0-M5, use {@link #builder()} instead
130+
*/
131+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
100132
public CosmosDBVectorStore(ObservationRegistry observationRegistry,
101133
VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient,
102134
CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) {
103-
super(observationRegistry, customObservationConvention);
104-
this.cosmosClient = cosmosClient;
105-
this.properties = properties;
106-
this.batchingStrategy = batchingStrategy;
107-
cosmosClient.createDatabaseIfNotExists(properties.getDatabaseName()).block();
135+
this(builder().cosmosClient(cosmosClient)
136+
.embeddingModel(embeddingModel)
137+
.containerName(properties.getContainerName())
138+
.databaseName(properties.getDatabaseName())
139+
.partitionKeyPath(properties.getPartitionKeyPath())
140+
.vectorStoreThroughput(properties.getVectorStoreThroughput())
141+
.vectorDimensions(properties.getVectorDimensions())
142+
.metadataFields(properties.getMetadataFieldsList())
143+
.observationRegistry(observationRegistry)
144+
.customObservationConvention(customObservationConvention)
145+
.batchingStrategy(batchingStrategy));
146+
}
108147

109-
initializeContainer(properties.getContainerName(), properties.getDatabaseName(),
110-
properties.getVectorStoreThroughput(), properties.getVectorDimensions(),
111-
properties.getPartitionKeyPath());
148+
/**
149+
* Protected constructor that accepts a builder instance. This is the preferred way to
150+
* create new CosmosDBVectorStore instances.
151+
* @param builder the configured builder instance
152+
*/
153+
protected CosmosDBVectorStore(CosmosDBBuilder builder) {
154+
super(builder);
155+
156+
Assert.notNull(builder.cosmosClient, "CosmosClient must not be null");
157+
Assert.hasText(builder.containerName, "Container name must not be empty");
158+
Assert.hasText(builder.databaseName, "Database name must not be empty");
159+
Assert.hasText(builder.partitionKeyPath, "Partition key path must not be empty");
160+
161+
this.cosmosClient = builder.cosmosClient;
162+
this.containerName = builder.containerName;
163+
this.databaseName = builder.databaseName;
164+
this.partitionKeyPath = builder.partitionKeyPath;
165+
this.vectorStoreThroughput = builder.vectorStoreThroughput;
166+
this.vectorDimensions = builder.vectorDimensions;
167+
this.metadataFieldsList = builder.metadataFieldsList;
168+
this.batchingStrategy = builder.batchingStrategy;
169+
170+
cosmosClient.createDatabaseIfNotExists(databaseName).block();
171+
initializeContainer(containerName, databaseName, vectorStoreThroughput, vectorDimensions, partitionKeyPath);
172+
}
112173

113-
this.embeddingModel = embeddingModel;
174+
public static CosmosDBBuilder builder() {
175+
return new CosmosDBBuilder();
114176
}
115177

116178
private void initializeContainer(String containerName, String databaseName, int vectorStoreThroughput,
@@ -308,7 +370,7 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
308370
Filter.Expression filterExpression = request.getFilterExpression();
309371
if (filterExpression != null) {
310372
CosmosDBFilterExpressionConverter filterExpressionConverter = new CosmosDBFilterExpressionConverter(
311-
this.properties.getMetadataFieldsList()); // Use the expression
373+
this.metadataFieldsList); // Use the expression
312374
// directly as
313375
// it handles the
314376
// "metadata"
@@ -359,4 +421,132 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
359421
.similarityMetric("cosine");
360422
}
361423

424+
/**
425+
* Builder class for creating {@link CosmosDBVectorStore} instances.
426+
* <p>
427+
* Provides a fluent API for configuring all aspects of the Cosmos DB vector store.
428+
*
429+
* @since 1.0.0
430+
*/
431+
public static class CosmosDBBuilder extends AbstractVectorStoreBuilder<CosmosDBBuilder> {
432+
433+
private CosmosAsyncClient cosmosClient;
434+
435+
private String containerName;
436+
437+
private String databaseName;
438+
439+
private String partitionKeyPath;
440+
441+
private int vectorStoreThroughput = 400;
442+
443+
private long vectorDimensions = 1536;
444+
445+
private List<String> metadataFieldsList = new ArrayList<>();
446+
447+
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
448+
449+
/**
450+
* Sets the Cosmos DB client.
451+
* @param cosmosClient the client to use
452+
* @return the builder instance
453+
* @throws IllegalArgumentException if cosmosClient is null
454+
*/
455+
public CosmosDBBuilder cosmosClient(CosmosAsyncClient cosmosClient) {
456+
Assert.notNull(cosmosClient, "CosmosClient must not be null");
457+
this.cosmosClient = cosmosClient;
458+
return this;
459+
}
460+
461+
/**
462+
* Sets the container name.
463+
* @param containerName the name of the container
464+
* @return the builder instance
465+
* @throws IllegalArgumentException if containerName is null or empty
466+
*/
467+
public CosmosDBBuilder containerName(String containerName) {
468+
Assert.hasText(containerName, "Container name must not be empty");
469+
this.containerName = containerName;
470+
return this;
471+
}
472+
473+
/**
474+
* Sets the database name.
475+
* @param databaseName the name of the database
476+
* @return the builder instance
477+
* @throws IllegalArgumentException if databaseName is null or empty
478+
*/
479+
public CosmosDBBuilder databaseName(String databaseName) {
480+
Assert.hasText(databaseName, "Database name must not be empty");
481+
this.databaseName = databaseName;
482+
return this;
483+
}
484+
485+
/**
486+
* Sets the partition key path.
487+
* @param partitionKeyPath the partition key path
488+
* @return the builder instance
489+
* @throws IllegalArgumentException if partitionKeyPath is null or empty
490+
*/
491+
public CosmosDBBuilder partitionKeyPath(String partitionKeyPath) {
492+
Assert.hasText(partitionKeyPath, "Partition key path must not be empty");
493+
this.partitionKeyPath = partitionKeyPath;
494+
return this;
495+
}
496+
497+
/**
498+
* Sets the vector store throughput.
499+
* @param vectorStoreThroughput the throughput value
500+
* @return the builder instance
501+
* @throws IllegalArgumentException if vectorStoreThroughput is not positive
502+
*/
503+
public CosmosDBBuilder vectorStoreThroughput(int vectorStoreThroughput) {
504+
Assert.isTrue(vectorStoreThroughput > 0, "Vector store throughput must be positive");
505+
this.vectorStoreThroughput = vectorStoreThroughput;
506+
return this;
507+
}
508+
509+
/**
510+
* Sets the vector dimensions.
511+
* @param vectorDimensions the number of dimensions
512+
* @return the builder instance
513+
* @throws IllegalArgumentException if vectorDimensions is not positive
514+
*/
515+
public CosmosDBBuilder vectorDimensions(long vectorDimensions) {
516+
Assert.isTrue(vectorDimensions > 0, "Vector dimensions must be positive");
517+
this.vectorDimensions = vectorDimensions;
518+
return this;
519+
}
520+
521+
/**
522+
* Sets the metadata fields list.
523+
* @param metadataFieldsList the list of metadata fields
524+
* @return the builder instance
525+
*/
526+
public CosmosDBBuilder metadataFields(List<String> metadataFieldsList) {
527+
this.metadataFieldsList = metadataFieldsList != null ? new ArrayList<>(metadataFieldsList)
528+
: new ArrayList<>();
529+
return this;
530+
}
531+
532+
/**
533+
* Sets the batching strategy.
534+
* @param batchingStrategy the strategy to use
535+
* @return the builder instance
536+
* @throws IllegalArgumentException if batchingStrategy is null
537+
*/
538+
public CosmosDBBuilder batchingStrategy(BatchingStrategy batchingStrategy) {
539+
Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
540+
this.batchingStrategy = batchingStrategy;
541+
return this;
542+
}
543+
544+
@Override
545+
public CosmosDBVectorStore build() {
546+
validate();
547+
return new CosmosDBVectorStore(this);
548+
}
549+
550+
}
551+
362552
}

0 commit comments

Comments
 (0)