Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ If all goes well, you should retrieve the document containing the text "Spring A
=== Run Chroma Locally

```shell
docker run -it --rm --name chroma -p 8000:8000 ghcr.io/chroma-core/chroma:0.4.15
docker run -it --rm --name chroma -p 8000:8000 ghcr.io/chroma-core/chroma:0.5.20
```

Starts a chroma store at <http://localhost:8000/api/v1>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
public class ChromaVectorStoreAutoConfigurationIT {

@Container
static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.0");
static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.20");

private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(ChromaVectorStoreAutoConfiguration.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
public final class ChromaImage {

public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11");
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20");

private ChromaImage() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -123,17 +124,6 @@ public Collection createCollection(CreateCollectionRequest createCollectionReque
.getBody();
}

public Map<String, Object> createCollection2(CreateCollectionRequest createCollectionRequest) {

return this.restClient.post()
.uri("/api/v1/collections")
.headers(this::httpHeaders)
.body(createCollectionRequest)
.retrieve()
.toEntity(Map.class)
.getBody();
}

/**
* Delete a collection with the given name.
* @param collectionName the name of the collection to delete.
Expand Down Expand Up @@ -281,7 +271,11 @@ private String getErrorMessage(HttpStatusCodeException e) {
* @param name The name of the collection.
* @param metadata Metadata associated with the collection.
*/
public record Collection(String id, String name, Map<String, Object> metadata) {
@JsonInclude(JsonInclude.Include.NON_NULL)
public record Collection(// @formatter:off
@JsonProperty("id") String id,
@JsonProperty("name") String name,
@JsonProperty("metadata") Map<String, Object> metadata) { // @formatter:on

}

Expand All @@ -291,7 +285,10 @@ public record Collection(String id, String name, Map<String, Object> metadata) {
* @param name The name of the collection to create.
* @param metadata Optional metadata to associate with the collection.
*/
public record CreateCollectionRequest(String name, Map<String, Object> metadata) {
@JsonInclude(JsonInclude.Include.NON_NULL)
public record CreateCollectionRequest(// @formatter:off
@JsonProperty("name") String name,
@JsonProperty("metadata") Map<String, Object> metadata) {// @formatter:on

public CreateCollectionRequest(String name) {
this(name, new HashMap<>(Map.of("hnsw:space", "cosine")));
Expand All @@ -300,7 +297,7 @@ public CreateCollectionRequest(String name) {
}

//
// Chroma Collection API (https://docs.trychroma.com/js_reference/Collection)
// Chroma Collection API (https://docs.trychroma.com/reference/js-client/Collection)
//

/**
Expand All @@ -312,14 +309,17 @@ public CreateCollectionRequest(String name) {
* can filter on this metadata.
* @param documents The documents contents to associate with the embeddings.
*/
public record AddEmbeddingsRequest(List<String> ids, List<float[]> embeddings,
@JsonProperty("metadatas") List<Map<String, Object>> metadata, List<String> documents) {
@JsonInclude(JsonInclude.Include.NON_NULL)
public record AddEmbeddingsRequest(// @formatter:off
@JsonProperty("ids") List<String> ids,
@JsonProperty("embeddings") List<float[]> embeddings,
@JsonProperty("metadatas") List<Map<String, Object>> metadata,
@JsonProperty("documents") List<String> documents) {// @formatter:on

// Convenance for adding a single embedding.
public AddEmbeddingsRequest(String id, float[] embedding, Map<String, Object> metadata, String document) {
this(List.of(id), List.of(embedding), List.of(metadata), List.of(document));
}

}

/**
Expand All @@ -329,12 +329,14 @@ public AddEmbeddingsRequest(String id, float[] embedding, Map<String, Object> me
* @param where Condition to filter items to delete based on metadata values.
* (Optional)
*/
public record DeleteEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
@JsonInclude(JsonInclude.Include.NON_NULL)
public record DeleteEmbeddingsRequest(// @formatter:off
@JsonProperty("ids") List<String> ids,
@JsonProperty("where") Map<String, Object> where) {// @formatter:on

public DeleteEmbeddingsRequest(List<String> ids) {
this(ids, Map.of());
this(ids, null);
}

}

/**
Expand All @@ -348,19 +350,24 @@ public DeleteEmbeddingsRequest(List<String> ids) {
* "metadatas", "documents", "distances". Ids are always included. Defaults to
* [metadatas, documents, distances].
*/
public record GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int limit, int offset,
List<Include> include) {
@JsonInclude(JsonInclude.Include.NON_NULL)
public record GetEmbeddingsRequest(// @formatter:off
@JsonProperty("ids") List<String> ids,
@JsonProperty("where") Map<String, Object> where,
@JsonProperty("limit") Integer limit,
@JsonProperty("offset") Integer offset,
@JsonProperty("include") List<Include> include) {// @formatter:on

public GetEmbeddingsRequest(List<String> ids) {
this(ids, Map.of(), 10, 0, Include.all);
this(ids, null, 10, 0, Include.all);
}

public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
this(ids, where, 10, 0, Include.all);
this(ids, CollectionUtils.isEmpty(where) ? null : where, 10, 0, Include.all);
}

public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int limit, int offset) {
this(ids, where, limit, offset, Include.all);
public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, Integer limit, Integer offset) {
this(ids, CollectionUtils.isEmpty(where) ? null : where, limit, offset, Include.all);
}

}
Expand All @@ -373,9 +380,12 @@ public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int lim
* @param documents List of document contents. One for each returned document.
* @param metadata List of document metadata. One for each returned document.
*/
public record GetEmbeddingResponse(List<String> ids, List<float[]> embeddings, List<String> documents,
@JsonProperty("metadatas") List<Map<String, String>> metadata) {

@JsonInclude(JsonInclude.Include.NON_NULL)
public record GetEmbeddingResponse(// @formatter:off
@JsonProperty("ids") List<String> ids,
@JsonProperty("embeddings") List<float[]> embeddings,
@JsonProperty("documents") List<String> documents,
@JsonProperty("metadatas") List<Map<String, String>> metadata) {// @formatter:on
}

/**
Expand All @@ -390,18 +400,22 @@ public record GetEmbeddingResponse(List<String> ids, List<float[]> embeddings, L
* "metadatas", "documents", "distances". Ids are always included. Defaults to
* [metadatas, documents, distances].
*/
public record QueryRequest(@JsonProperty("query_embeddings") List<float[]> queryEmbeddings,
@JsonProperty("n_results") int nResults, Map<String, Object> where, List<Include> include) {
@JsonInclude(JsonInclude.Include.NON_NULL)
public record QueryRequest( // @formatter:off
@JsonProperty("query_embeddings") List<float[]> queryEmbeddings,
@JsonProperty("n_results") Integer nResults,
@JsonProperty("where") Map<String, Object> where,
@JsonProperty("include") List<Include> include) { // @formatter:on

/**
* Convenience to query for a single embedding instead of a batch of embeddings.
*/
public QueryRequest(float[] queryEmbedding, int nResults) {
this(List.of(queryEmbedding), nResults, Map.of(), Include.all);
public QueryRequest(float[] queryEmbedding, Integer nResults) {
this(List.of(queryEmbedding), nResults, null, Include.all);
}

public QueryRequest(float[] queryEmbedding, int nResults, Map<String, Object> where) {
this(List.of(queryEmbedding), nResults, where, Include.all);
public QueryRequest(float[] queryEmbedding, Integer nResults, Map<String, Object> where) {
this(List.of(queryEmbedding), nResults, CollectionUtils.isEmpty(where) ? null : where, Include.all);
}

public enum Include {
Expand Down Expand Up @@ -434,9 +448,13 @@ public enum Include {
* @param metadata List of list of document metadata. One for each returned document.
* @param distances List of list of search distances. One for each returned document.
*/
public record QueryResponse(List<List<String>> ids, List<List<float[]>> embeddings, List<List<String>> documents,
@JsonProperty("metadatas") List<List<Map<String, Object>>> metadata, List<List<Double>> distances) {

@JsonInclude(JsonInclude.Include.NON_NULL)
public record QueryResponse(// @formatter:off
@JsonProperty("ids") List<List<String>> ids,
@JsonProperty("embeddings") List<List<float[]>> embeddings,
@JsonProperty("documents") List<List<String>> documents,
@JsonProperty("metadatas") List<List<Map<String, Object>>> metadata,
@JsonProperty("distances") List<List<Double>> distances) {// @formatter:on
}

/**
Expand All @@ -448,8 +466,13 @@ public record QueryResponse(List<List<String>> ids, List<List<float[]>> embeddin
* @param metadata The metadata of the document.
* @param distances The distance of the document to the query embedding.
*/
public record Embedding(String id, float[] embedding, String document, Map<String, Object> metadata,
Double distances) {
@JsonInclude(JsonInclude.Include.NON_NULL)
public record Embedding(// @formatter:off
@JsonProperty("id") String id,
@JsonProperty("embedding") float[] embedding,
@JsonProperty("document") String document,
@JsonProperty("metadata") Map<String, Object> metadata,
@JsonProperty("distances") Double distances) {// @formatter:on

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
* {@link ChromaVectorStore} is a concrete implementation of the {@link VectorStore}
Expand Down Expand Up @@ -134,7 +134,7 @@ public void setFilterExpressionConverter(FilterExpressionConverter filterExpress
}

@Override
public void doAdd(List<Document> documents) {
public void doAdd(@NonNull List<Document> documents) {
Assert.notNull(documents, "Documents must not be null");
if (CollectionUtils.isEmpty(documents)) {
return;
Expand All @@ -160,24 +160,23 @@ public void doAdd(List<Document> documents) {
}

@Override
public Optional<Boolean> doDelete(List<String> idList) {
public Optional<Boolean> doDelete(@NonNull List<String> idList) {
Assert.notNull(idList, "Document id list must not be null");
int status = this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList));
return Optional.of(status == 200);
}

@Override
public List<Document> doSimilaritySearch(SearchRequest request) {

String nativeFilterExpression = (request.getFilterExpression() != null)
? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
public @NonNull List<Document> doSimilaritySearch(@NonNull SearchRequest request) {

String query = request.getQuery();
Assert.notNull(query, "Query string must not be null");

float[] embedding = this.embeddingModel.embed(query);
Map<String, Object> where = (StringUtils.hasText(nativeFilterExpression)) ? jsonToMap(nativeFilterExpression)
: Map.of();

Map<String, Object> where = (request.getFilterExpression() != null)
? jsonToMap(this.filterExpressionConverter.convertExpression(request.getFilterExpression())) : null;

var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where);
var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest);
var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse);
Expand Down Expand Up @@ -241,7 +240,8 @@ public void afterPropertiesSet() throws Exception {
}

@Override
public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
public @NonNull VectorStoreObservationContext.Builder createObservationContextBuilder(
@NonNull String operationName) {
return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName)
.withDimensions(this.embeddingModel.dimensions())
.withCollectionName(this.collectionName + ":" + this.collectionId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
public final class ChromaImage {

public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.16");
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20");

private ChromaImage() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,34 @@ public void addAndSearch() {
});
}

@Test
public void simpleSearch() {
this.contextRunner.run(context -> {

VectorStore vectorStore = context.getBean(VectorStore.class);

var document = Document.builder()
.withId("simpleDoc")
.withContent("The sky is blue because of Rayleigh scattering.")
.build();

vectorStore.add(List.of(document));

List<Document> results = vectorStore.similaritySearch("Why is the sky blue?");

assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The sky is blue because of Rayleigh scattering.");

// Remove all documents from the store
assertThat(vectorStore.delete(List.of(document.getId()))).isEqualTo(Optional.of(Boolean.TRUE));

results = vectorStore.similaritySearch(SearchRequest.query("Why is the sky blue?"));
assertThat(results).hasSize(0);
});
}

@Test
public void addAndSearchWithFilters() {

Expand Down