diff --git a/vector-stores/spring-ai-chroma-store/pom.xml b/vector-stores/spring-ai-chroma-store/pom.xml index 3e7b35c4a8b..88f6738c4af 100644 --- a/vector-stores/spring-ai-chroma-store/pom.xml +++ b/vector-stores/spring-ai-chroma-store/pom.xml @@ -87,6 +87,13 @@ micrometer-observation-test test + + + org.springframework.ai + spring-ai-transformers + ${parent.version} + test + diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 90899516ae9..ba1166927c6 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -41,7 +41,6 @@ import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; -import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; @@ -58,7 +57,7 @@ * @author Christian Tzolov * @author Fu Cheng * @author Sebastien Deleuze - * + * @author Soby Chacko */ public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -66,10 +65,6 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection"; - public static final double SIMILARITY_THRESHOLD_ALL = 0.0; - - public static final int DEFAULT_TOP_K = 4; - private final EmbeddingModel embeddingModel; private final ChromaApi chromaApi; @@ -86,6 +81,8 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements private final ObjectMapper objectMapper; + private boolean initialized = false; + public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, boolean initializeSchema) { this(embeddingModel, chromaApi, DEFAULT_COLLECTION_NAME, initializeSchema); } @@ -111,6 +108,26 @@ public ChromaVectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi, Str this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build(); } + private ChromaVectorStore(Builder builder) { + super(builder.observationRegistry, builder.customObservationConvention); + this.embeddingModel = builder.embeddingModel; + this.chromaApi = builder.chromaApi; + this.collectionName = builder.collectionName; + this.initializeSchema = builder.initializeSchema; + this.filterExpressionConverter = builder.filterExpressionConverter; + this.batchingStrategy = builder.batchingStrategy; + this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build(); + + if (builder.initializeImmediately) { + try { + afterPropertiesSet(); + } + catch (Exception e) { + throw new IllegalStateException("Failed to initialize ChromaVectorStore", e); + } + } + } + public void setFilterExpressionConverter(FilterExpressionConverter filterExpressionConverter) { Assert.notNull(filterExpressionConverter, "FilterExpressionConverter should not be null."); this.filterExpressionConverter = filterExpressionConverter; @@ -207,26 +224,95 @@ public String getCollectionId() { @Override public void afterPropertiesSet() throws Exception { - var collection = this.chromaApi.getCollection(this.collectionName); - if (collection == null) { - if (this.initializeSchema) { - collection = this.chromaApi - .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName)); - } - else { - throw new RuntimeException("Collection " + this.collectionName - + " doesn't exist and won't be created as the initializeSchema is set to false."); + if (!this.initialized) { + var collection = this.chromaApi.getCollection(this.collectionName); + if (collection == null) { + if (this.initializeSchema) { + collection = this.chromaApi + .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName)); + } + else { + throw new RuntimeException("Collection " + this.collectionName + + " doesn't exist and won't be created as the initializeSchema is set to false."); + } } + this.collectionId = collection.id(); + this.initialized = true; } - this.collectionId = collection.id(); } @Override - public Builder createObservationContextBuilder(String operationName) { + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName) .withDimensions(this.embeddingModel.dimensions()) .withCollectionName(this.collectionName + ":" + this.collectionId) .withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null); } + public static class Builder { + + private final EmbeddingModel embeddingModel; + + private final ChromaApi chromaApi; + + private String collectionName = DEFAULT_COLLECTION_NAME; + + private boolean initializeSchema = false; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private VectorStoreObservationConvention customObservationConvention = null; + + private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + + private FilterExpressionConverter filterExpressionConverter = new ChromaFilterExpressionConverter(); + + private boolean initializeImmediately = false; + + public Builder(EmbeddingModel embeddingModel, ChromaApi chromaApi) { + this.embeddingModel = embeddingModel; + this.chromaApi = chromaApi; + } + + public Builder collectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + public Builder initializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public Builder customObservationConvention(VectorStoreObservationConvention convention) { + this.customObservationConvention = convention; + return this; + } + + public Builder batchingStrategy(BatchingStrategy batchingStrategy) { + this.batchingStrategy = batchingStrategy; + return this; + } + + public Builder filterExpressionConverter(FilterExpressionConverter converter) { + this.filterExpressionConverter = converter; + return this; + } + + public Builder initializeImmediately(boolean initialize) { + this.initializeImmediately = initialize; + return this; + } + + public ChromaVectorStore build() { + return new ChromaVectorStore(this); + } + + } + } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java index bfc84340294..0b01ae35486 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java @@ -16,6 +16,7 @@ package org.springframework.ai.chroma; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -30,17 +31,24 @@ import org.springframework.ai.chroma.ChromaApi.Collection; import org.springframework.ai.chroma.ChromaApi.GetEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.QueryRequest; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.transformers.TransformersEmbeddingModel; +import org.springframework.ai.vectorstore.ChromaVectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; /** * @author Christian Tzolov * @author EddĂș MelĂ©ndez * @author Thomas Vitale + * @author Soby Chacko */ @SpringBootTest @Testcontainers @@ -50,17 +58,20 @@ public class ChromaApiIT { static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE); @Autowired - ChromaApi chroma; + ChromaApi chromaApi; + + @Autowired + EmbeddingModel embeddingModel; @BeforeEach public void beforeEach() { - this.chroma.listCollections().stream().forEach(c -> this.chroma.deleteCollection(c.name())); + this.chromaApi.listCollections().stream().forEach(c -> this.chromaApi.deleteCollection(c.name())); } @Test public void testClientWithMetadata() { Map metadata = Map.of("hnsw:space", "cosine", "hnsw:M", 5); - var newCollection = this.chroma + var newCollection = this.chromaApi .createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata)); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); @@ -68,44 +79,44 @@ public void testClientWithMetadata() { @Test public void testClient() { - var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); - var getCollection = this.chroma.getCollection("TestCollection"); + var getCollection = this.chromaApi.getCollection("TestCollection"); assertThat(getCollection).isNotNull(); assertThat(getCollection.name()).isEqualTo("TestCollection"); assertThat(getCollection.id()).isEqualTo(newCollection.id()); - List collections = this.chroma.listCollections(); + List collections = this.chromaApi.listCollections(); assertThat(collections).hasSize(1); assertThat(collections.get(0).id()).isEqualTo(newCollection.id()); - this.chroma.deleteCollection(newCollection.name()); - assertThat(this.chroma.listCollections()).hasSize(0); + this.chromaApi.deleteCollection(newCollection.name()); + assertThat(this.chromaApi.listCollections()).hasSize(0); } @Test public void testCollection() { - var newCollection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); - assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(0); + var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(0); var addEmbeddingRequest = new AddEmbeddingsRequest(List.of("id1", "id2"), List.of(new float[] { 1f, 1f, 1f }, new float[] { 2f, 2f, 2f }), List.of(Map.of(), Map.of("key1", "value1", "key2", true, "key3", 23.4)), List.of("Hello World", "Big World")); - this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); + this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); var addEmbeddingRequest2 = new AddEmbeddingsRequest("id3", new float[] { 3f, 3f, 3f }, Map.of("key1", "value1", "key2", true, "key3", 23.4), "Big World"); - this.chroma.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); + this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); - assertThat(this.chroma.countEmbeddings(newCollection.id())).isEqualTo(3); + assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(3); - var queryResult = this.chroma.queryCollection(newCollection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" + var queryResult = this.chromaApi.queryCollection(newCollection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "key2" : { "$eq": true } } @@ -114,14 +125,14 @@ public void testCollection() { assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id2", "id3"); // Update existing embedding. - this.chroma.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, + this.chromaApi.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World")); - var result = this.chroma.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); + var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); assertThat(result.ids().get(0)).isEqualTo("id2"); - queryResult = this.chroma.queryCollection(newCollection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" + queryResult = this.chromaApi.queryCollection(newCollection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "key2" : { "$eq": true } } @@ -133,7 +144,7 @@ public void testCollection() { @Test public void testQueryWhere() { - var collection = this.chroma.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); var add1 = new AddEmbeddingsRequest("id1", new float[] { 1f, 1f, 1f }, Map.of("country", "BG", "active", true, "price", 23.4, "year", 2020), @@ -146,24 +157,25 @@ public void testQueryWhere() { Map.of("country", "BG", "active", false, "price", 40.1, "year", 2023), "The World is Big and Salvation Lurks Around the Corner"); - this.chroma.upsertEmbeddings(collection.id(), add1); - this.chroma.upsertEmbeddings(collection.id(), add2); - this.chroma.upsertEmbeddings(collection.id(), add3); + this.chromaApi.upsertEmbeddings(collection.id(), add1); + this.chromaApi.upsertEmbeddings(collection.id(), add2); + this.chromaApi.upsertEmbeddings(collection.id(), add3); - assertThat(this.chroma.countEmbeddings(collection.id())).isEqualTo(3); + assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3); - var queryResult = this.chroma.queryCollection(collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); + var queryResult = this.chromaApi.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); assertThat(queryResult.ids().get(0)).hasSize(3); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3"); - var chromaEmbeddings = this.chroma.toEmbeddingResponseList(queryResult); + var chromaEmbeddings = this.chromaApi.toEmbeddingResponseList(queryResult); assertThat(chromaEmbeddings).hasSize(3); assertThat(chromaEmbeddings).hasSize(3); - queryResult = this.chroma.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" + queryResult = this.chromaApi.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "$and" : [ {"country" : { "$eq": "BG"}}, @@ -174,8 +186,8 @@ public void testQueryWhere() { assertThat(queryResult.ids().get(0)).hasSize(2); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id3"); - queryResult = this.chroma.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chroma.where(""" + queryResult = this.chromaApi.queryCollection(collection.id(), + new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "$and" : [ {"country" : { "$eq": "BG"}}, @@ -188,6 +200,53 @@ public void testQueryWhere() { assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1"); } + @Test + void shouldUseExistingCollectionWhenSchemaInitializationDisabled() { // initializeSchema + // is false by + // default. + var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("test-collection")); + assertThat(collection).isNotNull(); + assertThat(collection.name()).isEqualTo("test-collection"); + + ChromaVectorStore store = new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi) + .collectionName("test-collection") + .initializeImmediately(true) + .build(); + + Document document = new Document("test content"); + assertThatNoException().isThrownBy(() -> store.add(Collections.singletonList(document))); + } + + @Test + void shouldCreateNewCollectionWhenSchemaInitializationEnabled() { + ChromaVectorStore store = new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi) + .collectionName("new-collection") + .initializeSchema(true) + .initializeImmediately(true) + .build(); + + var collection = this.chromaApi.getCollection("new-collection"); + assertThat(collection).isNotNull(); + assertThat(collection.name()).isEqualTo("new-collection"); + + Document document = new Document("test content"); + assertThatNoException().isThrownBy(() -> store.add(Collections.singletonList(document))); + } + + @Test + void shouldFailWhenCollectionDoesNotExist() { + assertThatThrownBy( + () -> new ChromaVectorStore.Builder(this.embeddingModel, this.chromaApi).collectionName("non-existent") + .initializeSchema(false) + .initializeImmediately(true) + .build()) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Failed to initialize ChromaVectorStore") + .hasCauseInstanceOf(RuntimeException.class) + .hasRootCauseMessage( + "Collection non-existent doesn't exist and won't be created as the initializeSchema is set to false."); + } + @SpringBootConfiguration public static class Config { @@ -196,6 +255,11 @@ public ChromaApi chromaApi() { return new ChromaApi(chromaContainer.getEndpoint()); } + @Bean + public EmbeddingModel embeddingModel() { + return new TransformersEmbeddingModel(); + } + } }