Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions vector-stores/spring-ai-chroma-store/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-transformers</artifactId>
<version>${parent.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
Expand All @@ -58,18 +57,14 @@
* @author Christian Tzolov
* @author Fu Cheng
* @author Sebastien Deleuze
*
* @author Soby Chacko
*/
public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean {

public static final String DISTANCE_FIELD_NAME = "distance";

public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection";

public static final double SIMILARITY_THRESHOLD_ALL = 0.0;

public static final int DEFAULT_TOP_K = 4;

private final EmbeddingModel embeddingModel;

private final ChromaApi chromaApi;
Expand All @@ -86,6 +81,8 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements

private final ObjectMapper objectMapper;

private boolean initialized = false;

public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, boolean initializeSchema) {
this(embeddingModel, chromaApi, DEFAULT_COLLECTION_NAME, initializeSchema);
}
Expand All @@ -111,6 +108,26 @@ public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, Str
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
}

private ChromaVectorStore(Builder builder) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is yet another variation of the builder pattern in our docs, any reason not to have this code inside the builder's build method?

super(builder.observationRegistry, builder.customObservationConvention);
this.embeddingModel = builder.embeddingModel;
this.chromaApi = builder.chromaApi;
this.collectionName = builder.collectionName;
this.initializeSchema = builder.initializeSchema;
this.filterExpressionConverter = builder.filterExpressionConverter;
this.batchingStrategy = builder.batchingStrategy;
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();

if (builder.initializeImmediately) {
try {
afterPropertiesSet();
}
catch (Exception e) {
throw new IllegalStateException("Failed to initialize ChromaVectorStore", e);
}
}
}

public void setFilterExpressionConverter(FilterExpressionConverter filterExpressionConverter) {
Assert.notNull(filterExpressionConverter, "FilterExpressionConverter should not be null.");
this.filterExpressionConverter = filterExpressionConverter;
Expand Down Expand Up @@ -207,26 +224,95 @@ public String getCollectionId() {

@Override
public void afterPropertiesSet() throws Exception {
var collection = this.chromaApi.getCollection(this.collectionName);
if (collection == null) {
if (this.initializeSchema) {
collection = this.chromaApi
.createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName));
}
else {
throw new RuntimeException("Collection " + this.collectionName
+ " doesn't exist and won't be created as the initializeSchema is set to false.");
if (!this.initialized) {
var collection = this.chromaApi.getCollection(this.collectionName);
if (collection == null) {
if (this.initializeSchema) {
collection = this.chromaApi
.createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName));
}
else {
throw new RuntimeException("Collection " + this.collectionName
+ " doesn't exist and won't be created as the initializeSchema is set to false.");
}
}
this.collectionId = collection.id();
this.initialized = true;
}
this.collectionId = collection.id();
}

@Override
public Builder createObservationContextBuilder(String operationName) {
public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName)
.withDimensions(this.embeddingModel.dimensions())
.withCollectionName(this.collectionName + ":" + this.collectionId)
.withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null);
}

public static class Builder {

private final EmbeddingModel embeddingModel;

private final ChromaApi chromaApi;

private String collectionName = DEFAULT_COLLECTION_NAME;

private boolean initializeSchema = false;

private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

private VectorStoreObservationConvention customObservationConvention = null;

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private FilterExpressionConverter filterExpressionConverter = new ChromaFilterExpressionConverter();

private boolean initializeImmediately = false;

public Builder(EmbeddingModel embeddingModel, ChromaApi chromaApi) {
this.embeddingModel = embeddingModel;
this.chromaApi = chromaApi;
}

public Builder collectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}

public Builder initializeSchema(boolean initializeSchema) {
this.initializeSchema = initializeSchema;
return this;
}

public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
return this;
}

public Builder customObservationConvention(VectorStoreObservationConvention convention) {
this.customObservationConvention = convention;
return this;
}

public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
this.batchingStrategy = batchingStrategy;
return this;
}

public Builder filterExpressionConverter(FilterExpressionConverter converter) {
this.filterExpressionConverter = converter;
return this;
}

public Builder initializeImmediately(boolean initialize) {
this.initializeImmediately = initialize;
return this;
}

public ChromaVectorStore build() {
return new ChromaVectorStore(this);
}

}

}
Loading