Skip to content

Commit c33fd4a

Browse files
committed
GH-1749 Fix Chroma query/delete operation
- When Chroma query/delete operation doesn't involve a where clause, the latest Chroma API doesn't allow an empty map for the where parameter. This requires both the query and delete operations to remove explicit setting of where parameter with empty map. - Check the filter clause to query with or without `where` parameter - Update both query and delete operations - Update tests - Update the latest Chroma image 0.5.18 in the tests Resolves #1749
1 parent 551206f commit c33fd4a

File tree

5 files changed

+122
-15
lines changed

5 files changed

+122
-15
lines changed

spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*/
2424
public final class ChromaImage {
2525

26-
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11");
26+
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.18");
2727

2828
private ChromaImage() {
2929

vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java

Lines changed: 105 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
*
4545
* @author Christian Tzolov
4646
* @author Eddú Meléndez
47+
* @author Ilayaperumal Gopinathan
4748
*/
4849
public class ChromaApi {
4950

@@ -187,6 +188,17 @@ public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding
187188
.toBodilessEntity();
188189
}
189190

191+
public int deleteEmbeddings(String collectionId, SimpleDeleteEmbeddingsRequest deleteRequest) {
192+
return this.restClient.post()
193+
.uri("/api/v1/collections/{collection_id}/delete", collectionId)
194+
.headers(this::httpHeaders)
195+
.body(deleteRequest)
196+
.retrieve()
197+
.toEntity(String.class)
198+
.getStatusCode()
199+
.value();
200+
}
201+
190202
public int deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteRequest) {
191203
return this.restClient.post()
192204
.uri("/api/v1/collections/{collection_id}/delete", collectionId)
@@ -219,6 +231,17 @@ public QueryResponse queryCollection(String collectionId, QueryRequest queryRequ
219231
.getBody();
220232
}
221233

234+
public QueryResponse simpleQueryCollection(String collectionId, SimpleQueryRequest queryRequest) {
235+
236+
return this.restClient.post()
237+
.uri("/api/v1/collections/{collection_id}/query", collectionId)
238+
.headers(this::httpHeaders)
239+
.body(queryRequest)
240+
.retrieve()
241+
.toEntity(QueryResponse.class)
242+
.getBody();
243+
}
244+
222245
//
223246
// Chroma Client API (https://docs.trychroma.com/js_reference/Client)
224247
//
@@ -234,6 +257,18 @@ public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequ
234257
.getBody();
235258
}
236259

260+
public GetEmbeddingResponse getEmbeddings(String collectionId,
261+
GetSimpleEmbeddingsRequest getSimpleEmbeddingsRequest) {
262+
263+
return this.restClient.post()
264+
.uri("/api/v1/collections/{collection_id}/get", collectionId)
265+
.headers(this::httpHeaders)
266+
.body(getSimpleEmbeddingsRequest)
267+
.retrieve()
268+
.toEntity(GetEmbeddingResponse.class)
269+
.getBody();
270+
}
271+
237272
// Utils
238273
public Map<String, Object> where(String text) {
239274
try {
@@ -331,9 +366,14 @@ public AddEmbeddingsRequest(String id, float[] embedding, Map<String, Object> me
331366
*/
332367
public record DeleteEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
333368

334-
public DeleteEmbeddingsRequest(List<String> ids) {
335-
this(ids, Map.of());
336-
}
369+
}
370+
371+
/**
372+
* Request to delete embedding from a collection.
373+
*
374+
* @param ids The ids of the embeddings to delete. (Optional)
375+
*/
376+
public record SimpleDeleteEmbeddingsRequest(List<String> ids) {
337377

338378
}
339379

@@ -351,10 +391,6 @@ public DeleteEmbeddingsRequest(List<String> ids) {
351391
public record GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int limit, int offset,
352392
List<Include> include) {
353393

354-
public GetEmbeddingsRequest(List<String> ids) {
355-
this(ids, Map.of(), 10, 0, Include.all);
356-
}
357-
358394
public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
359395
this(ids, where, 10, 0, Include.all);
360396
}
@@ -365,6 +401,27 @@ public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int lim
365401

366402
}
367403

404+
/**
405+
* Get embeddings from a collection.
406+
*
407+
* @param ids IDs of the embeddings to get.
408+
* @param limit Limit on the number of collection embeddings to get.
409+
* @param offset Offset on the embeddings to get.
410+
* @param include A list of what to include in the results. Can contain "embeddings",
411+
* "metadatas", "documents", "distances". Ids are always included. Defaults to
412+
* [metadatas, documents, distances].
413+
*/
414+
public record GetSimpleEmbeddingsRequest(List<String> ids, int limit, int offset, List<Include> include) {
415+
416+
public GetSimpleEmbeddingsRequest(List<String> ids) {
417+
this(ids, 10, 0, Include.all);
418+
}
419+
420+
public GetSimpleEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
421+
this(ids, 10, 0, Include.all);
422+
}
423+
}
424+
368425
/**
369426
* Object containing the get embedding results.
370427
*
@@ -424,6 +481,47 @@ public enum Include {
424481

425482
}
426483

484+
/**
485+
* Request to get the nResults nearest neighbor embeddings for provided
486+
* queryEmbeddings.
487+
*
488+
* @param queryEmbeddings The embeddings to get the closes neighbors of.
489+
* @param nResults The number of neighbors to return for each query_embedding or
490+
* query_texts.
491+
* @param include A list of what to include in the results. Can contain "embeddings",
492+
* "metadatas", "documents", "distances". Ids are always included. Defaults to
493+
* [metadatas, documents, distances].
494+
*/
495+
public record SimpleQueryRequest(@JsonProperty("query_embeddings") List<float[]> queryEmbeddings,
496+
@JsonProperty("n_results") int nResults, List<Include> include) {
497+
498+
/**
499+
* Convenience to query for a single embedding instead of a batch of embeddings.
500+
*/
501+
public SimpleQueryRequest(float[] queryEmbedding, int nResults) {
502+
this(List.of(queryEmbedding), nResults, Include.all);
503+
}
504+
505+
public enum Include {
506+
507+
@JsonProperty("metadatas")
508+
METADATAS,
509+
510+
@JsonProperty("documents")
511+
DOCUMENTS,
512+
513+
@JsonProperty("distances")
514+
DISTANCES,
515+
516+
@JsonProperty("embeddings")
517+
EMBEDDINGS;
518+
519+
public static final List<Include> all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS);
520+
521+
}
522+
523+
}
524+
427525
/**
428526
* A QueryResponse object containing the query results.
429527
*

vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ public void doAdd(List<Document> documents) {
162162
@Override
163163
public Optional<Boolean> doDelete(List<String> idList) {
164164
Assert.notNull(idList, "Document id list must not be null");
165-
int status = this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList));
165+
int status = this.chromaApi.deleteEmbeddings(this.collectionId,
166+
new ChromaApi.SimpleDeleteEmbeddingsRequest(idList));
166167
return Optional.of(status == 200);
167168
}
168169

@@ -178,8 +179,15 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
178179
float[] embedding = this.embeddingModel.embed(query);
179180
Map<String, Object> where = (StringUtils.hasText(nativeFilterExpression)) ? jsonToMap(nativeFilterExpression)
180181
: Map.of();
181-
var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where);
182-
var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest);
182+
ChromaApi.QueryResponse queryResponse = null;
183+
if (where.isEmpty()) {
184+
queryResponse = this.chromaApi.simpleQueryCollection(this.collectionId,
185+
new ChromaApi.SimpleQueryRequest(embedding, request.getTopK()));
186+
}
187+
else {
188+
queryResponse = this.chromaApi.queryCollection(this.collectionId,
189+
new ChromaApi.QueryRequest(embedding, request.getTopK(), where));
190+
}
183191
var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse);
184192

185193
List<Document> responseDocuments = new ArrayList<>();

vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*/
2424
public final class ChromaImage {
2525

26-
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.16");
26+
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.18");
2727

2828
private ChromaImage() {
2929

vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ public void testCollection() {
128128
this.chromaApi.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f },
129129
Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World"));
130130

131-
var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2")));
131+
var result = this.chromaApi.getEmbeddings(newCollection.id(),
132+
new ChromaApi.GetSimpleEmbeddingsRequest((List.of("id2"))));
132133
assertThat(result.ids().get(0)).isEqualTo("id2");
133134

134135
queryResult = this.chromaApi.queryCollection(newCollection.id(),
@@ -163,8 +164,8 @@ public void testQueryWhere() {
163164

164165
assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3);
165166

166-
var queryResult = this.chromaApi.queryCollection(collection.id(),
167-
new QueryRequest(new float[] { 1f, 1f, 1f }, 3));
167+
var queryResult = this.chromaApi.simpleQueryCollection(collection.id(),
168+
new ChromaApi.SimpleQueryRequest(new float[] { 1f, 1f, 1f }, 3));
168169

169170
assertThat(queryResult.ids().get(0)).hasSize(3);
170171
assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3");

0 commit comments

Comments
 (0)