diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java index 47491a9f80c..65918d174ef 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java @@ -23,7 +23,7 @@ */ public final class ChromaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.18"); private ChromaImage() { diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java index 743a4670c0c..871b6f25cd7 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -44,6 +44,7 @@ * * @author Christian Tzolov * @author EddĂș MelĂ©ndez + * @author Ilayaperumal Gopinathan */ public class ChromaApi { @@ -187,6 +188,17 @@ public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding .toBodilessEntity(); } + public int deleteEmbeddings(String collectionId, SimpleDeleteEmbeddingsRequest deleteRequest) { + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/delete", collectionId) + .headers(this::httpHeaders) + .body(deleteRequest) + .retrieve() + .toEntity(String.class) + .getStatusCode() + .value(); + } + public int deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteRequest) { return this.restClient.post() .uri("/api/v1/collections/{collection_id}/delete", collectionId) @@ -219,6 +231,17 @@ public QueryResponse queryCollection(String collectionId, QueryRequest queryRequ .getBody(); } + public QueryResponse simpleQueryCollection(String collectionId, SimpleQueryRequest queryRequest) { + + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/query", collectionId) + .headers(this::httpHeaders) + .body(queryRequest) + .retrieve() + .toEntity(QueryResponse.class) + .getBody(); + } + // // Chroma Client API (https://docs.trychroma.com/js_reference/Client) // @@ -234,6 +257,18 @@ public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequ .getBody(); } + public GetEmbeddingResponse getEmbeddings(String collectionId, + GetSimpleEmbeddingsRequest getSimpleEmbeddingsRequest) { + + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/get", collectionId) + .headers(this::httpHeaders) + .body(getSimpleEmbeddingsRequest) + .retrieve() + .toEntity(GetEmbeddingResponse.class) + .getBody(); + } + // Utils public Map where(String text) { try { @@ -331,9 +366,14 @@ public AddEmbeddingsRequest(String id, float[] embedding, Map me */ public record DeleteEmbeddingsRequest(List ids, Map where) { - public DeleteEmbeddingsRequest(List ids) { - this(ids, Map.of()); - } + } + + /** + * Request to delete embedding from a collection. + * + * @param ids The ids of the embeddings to delete. (Optional) + */ + public record SimpleDeleteEmbeddingsRequest(List ids) { } @@ -351,10 +391,6 @@ public DeleteEmbeddingsRequest(List ids) { public record GetEmbeddingsRequest(List ids, Map where, int limit, int offset, List include) { - public GetEmbeddingsRequest(List ids) { - this(ids, Map.of(), 10, 0, Include.all); - } - public GetEmbeddingsRequest(List ids, Map where) { this(ids, where, 10, 0, Include.all); } @@ -365,6 +401,27 @@ public GetEmbeddingsRequest(List ids, Map where, int lim } + /** + * Get embeddings from a collection. + * + * @param ids IDs of the embeddings to get. + * @param limit Limit on the number of collection embeddings to get. + * @param offset Offset on the embeddings to get. + * @param include A list of what to include in the results. Can contain "embeddings", + * "metadatas", "documents", "distances". Ids are always included. Defaults to + * [metadatas, documents, distances]. + */ + public record GetSimpleEmbeddingsRequest(List ids, int limit, int offset, List include) { + + public GetSimpleEmbeddingsRequest(List ids) { + this(ids, 10, 0, Include.all); + } + + public GetSimpleEmbeddingsRequest(List ids, Map where) { + this(ids, 10, 0, Include.all); + } + } + /** * Object containing the get embedding results. * @@ -424,6 +481,47 @@ public enum Include { } + /** + * Request to get the nResults nearest neighbor embeddings for provided + * queryEmbeddings. + * + * @param queryEmbeddings The embeddings to get the closes neighbors of. + * @param nResults The number of neighbors to return for each query_embedding or + * query_texts. + * @param include A list of what to include in the results. Can contain "embeddings", + * "metadatas", "documents", "distances". Ids are always included. Defaults to + * [metadatas, documents, distances]. + */ + public record SimpleQueryRequest(@JsonProperty("query_embeddings") List queryEmbeddings, + @JsonProperty("n_results") int nResults, List include) { + + /** + * Convenience to query for a single embedding instead of a batch of embeddings. + */ + public SimpleQueryRequest(float[] queryEmbedding, int nResults) { + this(List.of(queryEmbedding), nResults, Include.all); + } + + public enum Include { + + @JsonProperty("metadatas") + METADATAS, + + @JsonProperty("documents") + DOCUMENTS, + + @JsonProperty("distances") + DISTANCES, + + @JsonProperty("embeddings") + EMBEDDINGS; + + public static final List all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS); + + } + + } + /** * A QueryResponse object containing the query results. * diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 1b83e9e7333..5aa4d2ffd7e 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -162,7 +162,8 @@ public void doAdd(List documents) { @Override public Optional doDelete(List idList) { Assert.notNull(idList, "Document id list must not be null"); - int status = this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList)); + int status = this.chromaApi.deleteEmbeddings(this.collectionId, + new ChromaApi.SimpleDeleteEmbeddingsRequest(idList)); return Optional.of(status == 200); } @@ -178,8 +179,15 @@ public List doSimilaritySearch(SearchRequest request) { float[] embedding = this.embeddingModel.embed(query); Map where = (StringUtils.hasText(nativeFilterExpression)) ? jsonToMap(nativeFilterExpression) : Map.of(); - var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where); - var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest); + ChromaApi.QueryResponse queryResponse = null; + if (where.isEmpty()) { + queryResponse = this.chromaApi.simpleQueryCollection(this.collectionId, + new ChromaApi.SimpleQueryRequest(embedding, request.getTopK())); + } + else { + queryResponse = this.chromaApi.queryCollection(this.collectionId, + new ChromaApi.QueryRequest(embedding, request.getTopK(), where)); + } var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse); List responseDocuments = new ArrayList<>(); diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java index a2b56266bfc..10827782faf 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java @@ -23,7 +23,7 @@ */ public final class ChromaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.16"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.18"); private ChromaImage() { diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java index 0b01ae35486..fc957bfe038 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java @@ -128,7 +128,8 @@ public void testCollection() { this.chromaApi.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World")); - var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); + var result = this.chromaApi.getEmbeddings(newCollection.id(), + new ChromaApi.GetSimpleEmbeddingsRequest((List.of("id2")))); assertThat(result.ids().get(0)).isEqualTo("id2"); queryResult = this.chromaApi.queryCollection(newCollection.id(), @@ -163,8 +164,8 @@ public void testQueryWhere() { assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3); - var queryResult = this.chromaApi.queryCollection(collection.id(), - new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); + var queryResult = this.chromaApi.simpleQueryCollection(collection.id(), + new ChromaApi.SimpleQueryRequest(new float[] { 1f, 1f, 1f }, 3)); assertThat(queryResult.ids().get(0)).hasSize(3); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3");