diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc index c1378069f9f..49cbdd1b5a3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc @@ -265,7 +265,7 @@ If all goes well, you should retrieve the document containing the text "Spring A === Run Chroma Locally ```shell -docker run -it --rm --name chroma -p 8000:8000 ghcr.io/chroma-core/chroma:0.4.15 +docker run -it --rm --name chroma -p 8000:8000 ghcr.io/chroma-core/chroma:0.5.20 ``` Starts a chroma store at diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java index b33f714cb67..4998f097e32 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java @@ -55,7 +55,7 @@ public class ChromaVectorStoreAutoConfigurationIT { @Container - static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.0"); + static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.20"); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(ChromaVectorStoreAutoConfiguration.class)) diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java index 47491a9f80c..3655f350ebc 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java @@ -23,7 +23,7 @@ */ public final class ChromaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20"); private ChromaImage() { diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java index 743a4670c0c..131191cee01 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -23,6 +23,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -123,17 +124,6 @@ public Collection createCollection(CreateCollectionRequest createCollectionReque .getBody(); } - public Map createCollection2(CreateCollectionRequest createCollectionRequest) { - - return this.restClient.post() - .uri("/api/v1/collections") - .headers(this::httpHeaders) - .body(createCollectionRequest) - .retrieve() - .toEntity(Map.class) - .getBody(); - } - /** * Delete a collection with the given name. * @param collectionName the name of the collection to delete. @@ -281,7 +271,11 @@ private String getErrorMessage(HttpStatusCodeException e) { * @param name The name of the collection. * @param metadata Metadata associated with the collection. */ - public record Collection(String id, String name, Map metadata) { + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Collection(// @formatter:off + @JsonProperty("id") String id, + @JsonProperty("name") String name, + @JsonProperty("metadata") Map metadata) { // @formatter:on } @@ -291,7 +285,10 @@ public record Collection(String id, String name, Map metadata) { * @param name The name of the collection to create. * @param metadata Optional metadata to associate with the collection. */ - public record CreateCollectionRequest(String name, Map metadata) { + @JsonInclude(JsonInclude.Include.NON_NULL) + public record CreateCollectionRequest(// @formatter:off + @JsonProperty("name") String name, + @JsonProperty("metadata") Map metadata) {// @formatter:on public CreateCollectionRequest(String name) { this(name, new HashMap<>(Map.of("hnsw:space", "cosine"))); @@ -300,7 +297,7 @@ public CreateCollectionRequest(String name) { } // - // Chroma Collection API (https://docs.trychroma.com/js_reference/Collection) + // Chroma Collection API (https://docs.trychroma.com/reference/js-client/Collection) // /** @@ -312,14 +309,17 @@ public CreateCollectionRequest(String name) { * can filter on this metadata. * @param documents The documents contents to associate with the embeddings. */ - public record AddEmbeddingsRequest(List ids, List embeddings, - @JsonProperty("metadatas") List> metadata, List documents) { + @JsonInclude(JsonInclude.Include.NON_NULL) + public record AddEmbeddingsRequest(// @formatter:off + @JsonProperty("ids") List ids, + @JsonProperty("embeddings") List embeddings, + @JsonProperty("metadatas") List> metadata, + @JsonProperty("documents") List documents) {// @formatter:on // Convenance for adding a single embedding. public AddEmbeddingsRequest(String id, float[] embedding, Map metadata, String document) { this(List.of(id), List.of(embedding), List.of(metadata), List.of(document)); } - } /** @@ -329,12 +329,14 @@ public AddEmbeddingsRequest(String id, float[] embedding, Map me * @param where Condition to filter items to delete based on metadata values. * (Optional) */ - public record DeleteEmbeddingsRequest(List ids, Map where) { + @JsonInclude(JsonInclude.Include.NON_NULL) + public record DeleteEmbeddingsRequest(// @formatter:off + @JsonProperty("ids") List ids, + @JsonProperty("where") Map where) {// @formatter:on public DeleteEmbeddingsRequest(List ids) { - this(ids, Map.of()); + this(ids, null); } - } /** @@ -348,19 +350,24 @@ public DeleteEmbeddingsRequest(List ids) { * "metadatas", "documents", "distances". Ids are always included. Defaults to * [metadatas, documents, distances]. */ - public record GetEmbeddingsRequest(List ids, Map where, int limit, int offset, - List include) { + @JsonInclude(JsonInclude.Include.NON_NULL) + public record GetEmbeddingsRequest(// @formatter:off + @JsonProperty("ids") List ids, + @JsonProperty("where") Map where, + @JsonProperty("limit") Integer limit, + @JsonProperty("offset") Integer offset, + @JsonProperty("include") List include) {// @formatter:on public GetEmbeddingsRequest(List ids) { - this(ids, Map.of(), 10, 0, Include.all); + this(ids, null, 10, 0, Include.all); } public GetEmbeddingsRequest(List ids, Map where) { - this(ids, where, 10, 0, Include.all); + this(ids, CollectionUtils.isEmpty(where) ? null : where, 10, 0, Include.all); } - public GetEmbeddingsRequest(List ids, Map where, int limit, int offset) { - this(ids, where, limit, offset, Include.all); + public GetEmbeddingsRequest(List ids, Map where, Integer limit, Integer offset) { + this(ids, CollectionUtils.isEmpty(where) ? null : where, limit, offset, Include.all); } } @@ -373,9 +380,12 @@ public GetEmbeddingsRequest(List ids, Map where, int lim * @param documents List of document contents. One for each returned document. * @param metadata List of document metadata. One for each returned document. */ - public record GetEmbeddingResponse(List ids, List embeddings, List documents, - @JsonProperty("metadatas") List> metadata) { - + @JsonInclude(JsonInclude.Include.NON_NULL) + public record GetEmbeddingResponse(// @formatter:off + @JsonProperty("ids") List ids, + @JsonProperty("embeddings") List embeddings, + @JsonProperty("documents") List documents, + @JsonProperty("metadatas") List> metadata) {// @formatter:on } /** @@ -390,18 +400,22 @@ public record GetEmbeddingResponse(List ids, List embeddings, L * "metadatas", "documents", "distances". Ids are always included. Defaults to * [metadatas, documents, distances]. */ - public record QueryRequest(@JsonProperty("query_embeddings") List queryEmbeddings, - @JsonProperty("n_results") int nResults, Map where, List include) { + @JsonInclude(JsonInclude.Include.NON_NULL) + public record QueryRequest( // @formatter:off + @JsonProperty("query_embeddings") List queryEmbeddings, + @JsonProperty("n_results") Integer nResults, + @JsonProperty("where") Map where, + @JsonProperty("include") List include) { // @formatter:on /** * Convenience to query for a single embedding instead of a batch of embeddings. */ - public QueryRequest(float[] queryEmbedding, int nResults) { - this(List.of(queryEmbedding), nResults, Map.of(), Include.all); + public QueryRequest(float[] queryEmbedding, Integer nResults) { + this(List.of(queryEmbedding), nResults, null, Include.all); } - public QueryRequest(float[] queryEmbedding, int nResults, Map where) { - this(List.of(queryEmbedding), nResults, where, Include.all); + public QueryRequest(float[] queryEmbedding, Integer nResults, Map where) { + this(List.of(queryEmbedding), nResults, CollectionUtils.isEmpty(where) ? null : where, Include.all); } public enum Include { @@ -434,9 +448,13 @@ public enum Include { * @param metadata List of list of document metadata. One for each returned document. * @param distances List of list of search distances. One for each returned document. */ - public record QueryResponse(List> ids, List> embeddings, List> documents, - @JsonProperty("metadatas") List>> metadata, List> distances) { - + @JsonInclude(JsonInclude.Include.NON_NULL) + public record QueryResponse(// @formatter:off + @JsonProperty("ids") List> ids, + @JsonProperty("embeddings") List> embeddings, + @JsonProperty("documents") List> documents, + @JsonProperty("metadatas") List>> metadata, + @JsonProperty("distances") List> distances) {// @formatter:on } /** @@ -448,8 +466,13 @@ public record QueryResponse(List> ids, List> embeddin * @param metadata The metadata of the document. * @param distances The distance of the document to the query embedding. */ - public record Embedding(String id, float[] embedding, String document, Map metadata, - Double distances) { + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Embedding(// @formatter:off + @JsonProperty("id") String id, + @JsonProperty("embedding") float[] embedding, + @JsonProperty("document") String document, + @JsonProperty("metadata") Map metadata, + @JsonProperty("distances") Double distances) {// @formatter:on } diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 1b83e9e7333..49deb147d65 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -43,9 +43,9 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.NonNull; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; /** * {@link ChromaVectorStore} is a concrete implementation of the {@link VectorStore} @@ -134,7 +134,7 @@ public void setFilterExpressionConverter(FilterExpressionConverter filterExpress } @Override - public void doAdd(List documents) { + public void doAdd(@NonNull List documents) { Assert.notNull(documents, "Documents must not be null"); if (CollectionUtils.isEmpty(documents)) { return; @@ -160,24 +160,23 @@ public void doAdd(List documents) { } @Override - public Optional doDelete(List idList) { + public Optional doDelete(@NonNull List idList) { Assert.notNull(idList, "Document id list must not be null"); int status = this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList)); return Optional.of(status == 200); } @Override - public List doSimilaritySearch(SearchRequest request) { - - String nativeFilterExpression = (request.getFilterExpression() != null) - ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : ""; + public @NonNull List doSimilaritySearch(@NonNull SearchRequest request) { String query = request.getQuery(); Assert.notNull(query, "Query string must not be null"); float[] embedding = this.embeddingModel.embed(query); - Map where = (StringUtils.hasText(nativeFilterExpression)) ? jsonToMap(nativeFilterExpression) - : Map.of(); + + Map where = (request.getFilterExpression() != null) + ? jsonToMap(this.filterExpressionConverter.convertExpression(request.getFilterExpression())) : null; + var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where); var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest); var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse); @@ -241,7 +240,8 @@ public void afterPropertiesSet() throws Exception { } @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + public @NonNull VectorStoreObservationContext.Builder createObservationContextBuilder( + @NonNull String operationName) { return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName) .withDimensions(this.embeddingModel.dimensions()) .withCollectionName(this.collectionName + ":" + this.collectionId) diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java index a2b56266bfc..c9961e5afb1 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java @@ -23,7 +23,7 @@ */ public final class ChromaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.16"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20"); private ChromaImage() { diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java index f24d8c62856..a2f6f000093 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java @@ -92,6 +92,34 @@ public void addAndSearch() { }); } + @Test + public void simpleSearch() { + this.contextRunner.run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + var document = Document.builder() + .withId("simpleDoc") + .withContent("The sky is blue because of Rayleigh scattering.") + .build(); + + vectorStore.add(List.of(document)); + + List results = vectorStore.similaritySearch("Why is the sky blue?"); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("The sky is blue because of Rayleigh scattering."); + + // Remove all documents from the store + assertThat(vectorStore.delete(List.of(document.getId()))).isEqualTo(Optional.of(Boolean.TRUE)); + + results = vectorStore.similaritySearch(SearchRequest.query("Why is the sky blue?")); + assertThat(results).hasSize(0); + }); + } + @Test public void addAndSearchWithFilters() {