From c33fd4afdb8ef96549d67609ec3677862bd23ec0 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Wed, 20 Nov 2024 12:48:19 +0000 Subject: [PATCH] GH-1749 Fix Chroma query/delete operation - When Chroma query/delete operation doesn't involve a where clause, the latest Chroma API doesn't allow an empty map for the where parameter. This requires both the query and delete operations to remove explicit setting of where parameter with empty map. - Check the filter clause to query with or without `where` parameter - Update both query and delete operations - Update tests - Update the latest Chroma image 0.5.18 in the tests Resolves #1749 --- .../connection/chroma/ChromaImage.java | 2 +- .../springframework/ai/chroma/ChromaApi.java | 112 ++++++++++++++++-- .../ai/vectorstore/ChromaVectorStore.java | 14 ++- .../org/springframework/ai/ChromaImage.java | 2 +- .../ai/chroma/ChromaApiIT.java | 7 +- 5 files changed, 122 insertions(+), 15 deletions(-) diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java index 47491a9f80c..65918d174ef 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java @@ -23,7 +23,7 @@ */ public final class ChromaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.18"); private ChromaImage() { diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java 
b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java index 743a4670c0c..871b6f25cd7 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -44,6 +44,7 @@ * * @author Christian Tzolov * @author Eddú Meléndez + * @author Ilayaperumal Gopinathan */ public class ChromaApi { @@ -187,6 +188,17 @@ public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding .toBodilessEntity(); } + public int deleteEmbeddings(String collectionId, SimpleDeleteEmbeddingsRequest deleteRequest) { + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/delete", collectionId) + .headers(this::httpHeaders) + .body(deleteRequest) + .retrieve() + .toEntity(String.class) + .getStatusCode() + .value(); + } + public int deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteRequest) { return this.restClient.post() .uri("/api/v1/collections/{collection_id}/delete", collectionId) @@ -219,6 +231,17 @@ public QueryResponse queryCollection(String collectionId, QueryRequest queryRequ .getBody(); } + public QueryResponse simpleQueryCollection(String collectionId, SimpleQueryRequest queryRequest) { + + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/query", collectionId) + .headers(this::httpHeaders) + .body(queryRequest) + .retrieve() + .toEntity(QueryResponse.class) + .getBody(); + } + // // Chroma Client API (https://docs.trychroma.com/js_reference/Client) // @@ -234,6 +257,18 @@ public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequ .getBody(); } + public GetEmbeddingResponse getEmbeddings(String collectionId, + GetSimpleEmbeddingsRequest getSimpleEmbeddingsRequest) { + + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/get", collectionId) + .headers(this::httpHeaders) 
.body(getSimpleEmbeddingsRequest) + .retrieve() + .toEntity(GetEmbeddingResponse.class) + .getBody(); + } + // Utils public Map where(String text) { try { @@ -331,9 +366,14 @@ public AddEmbeddingsRequest(String id, float[] embedding, Map me */ public record DeleteEmbeddingsRequest(List ids, Map where) { - public DeleteEmbeddingsRequest(List ids) { - this(ids, Map.of()); - } + } + + /** + * Request to delete embeddings from a collection. + * + * @param ids The ids of the embeddings to delete. (Optional) + */ + public record SimpleDeleteEmbeddingsRequest(List ids) { } @@ -351,10 +391,6 @@ public DeleteEmbeddingsRequest(List ids) { public record GetEmbeddingsRequest(List ids, Map where, int limit, int offset, List include) { - public GetEmbeddingsRequest(List ids) { - this(ids, Map.of(), 10, 0, Include.all); - } - public GetEmbeddingsRequest(List ids, Map where) { this(ids, where, 10, 0, Include.all); } @@ -365,6 +401,27 @@ public GetEmbeddingsRequest(List ids, Map where, int lim } + /** + * Get embeddings from a collection. + * + * @param ids IDs of the embeddings to get. + * @param limit Limit on the number of collection embeddings to get. + * @param offset Offset on the embeddings to get. + * @param include A list of what to include in the results. Can contain "embeddings", + * "metadatas", "documents", "distances". Ids are always included. Defaults to + * [metadatas, documents, distances]. + */ + public record GetSimpleEmbeddingsRequest(List ids, int limit, int offset, List include) { + + public GetSimpleEmbeddingsRequest(List ids) { + this(ids, 10, 0, Include.all); + } + + public GetSimpleEmbeddingsRequest(List ids, Map where) { + this(ids, 10, 0, Include.all); + } + } + /** * Object containing the get embedding results. * @@ -424,6 +481,47 @@ public enum Include { } + /** + * Request to get the nResults nearest neighbor embeddings for provided + * queryEmbeddings. + * + * @param queryEmbeddings The embeddings to get the closest neighbors of. 
+ * @param nResults The number of neighbors to return for each query_embedding or + * query_texts. + * @param include A list of what to include in the results. Can contain "embeddings", + * "metadatas", "documents", "distances". Ids are always included. Defaults to + * [metadatas, documents, distances]. + */ + public record SimpleQueryRequest(@JsonProperty("query_embeddings") List queryEmbeddings, + @JsonProperty("n_results") int nResults, List include) { + + /** + * Convenience to query for a single embedding instead of a batch of embeddings. + */ + public SimpleQueryRequest(float[] queryEmbedding, int nResults) { + this(List.of(queryEmbedding), nResults, Include.all); + } + + public enum Include { + + @JsonProperty("metadatas") + METADATAS, + + @JsonProperty("documents") + DOCUMENTS, + + @JsonProperty("distances") + DISTANCES, + + @JsonProperty("embeddings") + EMBEDDINGS; + + public static final List all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS); + + } + + } + /** * A QueryResponse object containing the query results. 
* diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 1b83e9e7333..5aa4d2ffd7e 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -162,7 +162,8 @@ public void doAdd(List documents) { @Override public Optional doDelete(List idList) { Assert.notNull(idList, "Document id list must not be null"); - int status = this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList)); + int status = this.chromaApi.deleteEmbeddings(this.collectionId, + new ChromaApi.SimpleDeleteEmbeddingsRequest(idList)); return Optional.of(status == 200); } @@ -178,8 +179,15 @@ public List doSimilaritySearch(SearchRequest request) { float[] embedding = this.embeddingModel.embed(query); Map where = (StringUtils.hasText(nativeFilterExpression)) ? 
jsonToMap(nativeFilterExpression) : Map.of(); - var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where); - var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest); + ChromaApi.QueryResponse queryResponse = null; + if (where.isEmpty()) { + queryResponse = this.chromaApi.simpleQueryCollection(this.collectionId, + new ChromaApi.SimpleQueryRequest(embedding, request.getTopK())); + } + else { + queryResponse = this.chromaApi.queryCollection(this.collectionId, + new ChromaApi.QueryRequest(embedding, request.getTopK(), where)); + } var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse); List responseDocuments = new ArrayList<>(); diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java index a2b56266bfc..10827782faf 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java @@ -23,7 +23,7 @@ */ public final class ChromaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.16"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.18"); private ChromaImage() { diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java index 0b01ae35486..fc957bfe038 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java @@ -128,7 +128,8 @@ public void testCollection() { this.chromaApi.upsertEmbeddings(newCollection.id(), 
new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World")); - var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); + var result = this.chromaApi.getEmbeddings(newCollection.id(), + new ChromaApi.GetSimpleEmbeddingsRequest((List.of("id2")))); assertThat(result.ids().get(0)).isEqualTo("id2"); queryResult = this.chromaApi.queryCollection(newCollection.id(), @@ -163,8 +164,8 @@ public void testQueryWhere() { assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3); - var queryResult = this.chromaApi.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); + var queryResult = this.chromaApi.simpleQueryCollection(collection.id(), + new ChromaApi.SimpleQueryRequest(new float[] { 1f, 1f, 1f }, 3)); assertThat(queryResult.ids().get(0)).hasSize(3); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3");