Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
public final class ChromaImage {

public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11");
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.18");

private ChromaImage() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
*
* @author Christian Tzolov
* @author Eddú Meléndez
* @author Ilayaperumal Gopinathan
*/
public class ChromaApi {

Expand Down Expand Up @@ -187,6 +188,17 @@ public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding
.toBodilessEntity();
}

public int deleteEmbeddings(String collectionId, SimpleDeleteEmbeddingsRequest deleteRequest) {
return this.restClient.post()
.uri("/api/v1/collections/{collection_id}/delete", collectionId)
.headers(this::httpHeaders)
.body(deleteRequest)
.retrieve()
.toEntity(String.class)
.getStatusCode()
.value();
}

public int deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteRequest) {
return this.restClient.post()
.uri("/api/v1/collections/{collection_id}/delete", collectionId)
Expand Down Expand Up @@ -219,6 +231,17 @@ public QueryResponse queryCollection(String collectionId, QueryRequest queryRequ
.getBody();
}

public QueryResponse simpleQueryCollection(String collectionId, SimpleQueryRequest queryRequest) {

return this.restClient.post()
.uri("/api/v1/collections/{collection_id}/query", collectionId)
.headers(this::httpHeaders)
.body(queryRequest)
.retrieve()
.toEntity(QueryResponse.class)
.getBody();
}

//
// Chroma Client API (https://docs.trychroma.com/js_reference/Client)
//
Expand All @@ -234,6 +257,18 @@ public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequ
.getBody();
}

public GetEmbeddingResponse getEmbeddings(String collectionId,
GetSimpleEmbeddingsRequest getSimpleEmbeddingsRequest) {

return this.restClient.post()
.uri("/api/v1/collections/{collection_id}/get", collectionId)
.headers(this::httpHeaders)
.body(getSimpleEmbeddingsRequest)
.retrieve()
.toEntity(GetEmbeddingResponse.class)
.getBody();
}

// Utils
public Map<String, Object> where(String text) {
try {
Expand Down Expand Up @@ -331,9 +366,14 @@ public AddEmbeddingsRequest(String id, float[] embedding, Map<String, Object> me
*/
public record DeleteEmbeddingsRequest(List<String> ids, Map<String, Object> where) {

public DeleteEmbeddingsRequest(List<String> ids) {
this(ids, Map.of());
}
}

/**
* Request to delete embedding from a collection.
*
* @param ids The ids of the embeddings to delete. (Optional)
*/
public record SimpleDeleteEmbeddingsRequest(List<String> ids) {

}

Expand All @@ -351,10 +391,6 @@ public DeleteEmbeddingsRequest(List<String> ids) {
public record GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int limit, int offset,
List<Include> include) {

public GetEmbeddingsRequest(List<String> ids) {
this(ids, Map.of(), 10, 0, Include.all);
}

public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
this(ids, where, 10, 0, Include.all);
}
Expand All @@ -365,6 +401,27 @@ public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int lim

}

/**
* Get embeddings from a collection.
*
* @param ids IDs of the embeddings to get.
* @param limit Limit on the number of collection embeddings to get.
* @param offset Offset on the embeddings to get.
* @param include A list of what to include in the results. Can contain "embeddings",
* "metadatas", "documents", "distances". Ids are always included. Defaults to
* [metadatas, documents, distances].
*/
public record GetSimpleEmbeddingsRequest(List<String> ids, int limit, int offset, List<Include> include) {

public GetSimpleEmbeddingsRequest(List<String> ids) {
this(ids, 10, 0, Include.all);
}

public GetSimpleEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
this(ids, 10, 0, Include.all);
}
}

/**
* Object containing the get embedding results.
*
Expand Down Expand Up @@ -424,6 +481,47 @@ public enum Include {

}

/**
* Request to get the nResults nearest neighbor embeddings for provided
* queryEmbeddings.
*
* @param queryEmbeddings The embeddings to get the closes neighbors of.
* @param nResults The number of neighbors to return for each query_embedding or
* query_texts.
* @param include A list of what to include in the results. Can contain "embeddings",
* "metadatas", "documents", "distances". Ids are always included. Defaults to
* [metadatas, documents, distances].
*/
public record SimpleQueryRequest(@JsonProperty("query_embeddings") List<float[]> queryEmbeddings,
@JsonProperty("n_results") int nResults, List<Include> include) {

/**
* Convenience to query for a single embedding instead of a batch of embeddings.
*/
public SimpleQueryRequest(float[] queryEmbedding, int nResults) {
this(List.of(queryEmbedding), nResults, Include.all);
}

public enum Include {

@JsonProperty("metadatas")
METADATAS,

@JsonProperty("documents")
DOCUMENTS,

@JsonProperty("distances")
DISTANCES,

@JsonProperty("embeddings")
EMBEDDINGS;

public static final List<Include> all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS);

}

}

/**
* A QueryResponse object containing the query results.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ public void doAdd(List<Document> documents) {
@Override
public Optional<Boolean> doDelete(List<String> idList) {
Assert.notNull(idList, "Document id list must not be null");
int status = this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList));
int status = this.chromaApi.deleteEmbeddings(this.collectionId,
new ChromaApi.SimpleDeleteEmbeddingsRequest(idList));
return Optional.of(status == 200);
}

Expand All @@ -178,8 +179,15 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
float[] embedding = this.embeddingModel.embed(query);
Map<String, Object> where = (StringUtils.hasText(nativeFilterExpression)) ? jsonToMap(nativeFilterExpression)
: Map.of();
var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where);
var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest);
ChromaApi.QueryResponse queryResponse = null;
if (where.isEmpty()) {
queryResponse = this.chromaApi.simpleQueryCollection(this.collectionId,
new ChromaApi.SimpleQueryRequest(embedding, request.getTopK()));
}
else {
queryResponse = this.chromaApi.queryCollection(this.collectionId,
new ChromaApi.QueryRequest(embedding, request.getTopK(), where));
}
var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse);

List<Document> responseDocuments = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
public final class ChromaImage {

public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.16");
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.18");

private ChromaImage() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ public void testCollection() {
this.chromaApi.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f },
Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World"));

var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2")));
var result = this.chromaApi.getEmbeddings(newCollection.id(),
new ChromaApi.GetSimpleEmbeddingsRequest((List.of("id2"))));
assertThat(result.ids().get(0)).isEqualTo("id2");

queryResult = this.chromaApi.queryCollection(newCollection.id(),
Expand Down Expand Up @@ -163,8 +164,8 @@ public void testQueryWhere() {

assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3);

var queryResult = this.chromaApi.queryCollection(collection.id(),
new QueryRequest(new float[] { 1f, 1f, 1f }, 3));
var queryResult = this.chromaApi.simpleQueryCollection(collection.id(),
new ChromaApi.SimpleQueryRequest(new float[] { 1f, 1f, 1f }, 3));

assertThat(queryResult.ids().get(0)).hasSize(3);
assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id2", "id3");
Expand Down