diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/pom.xml b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/pom.xml
index a852dc773fe..9eee24b3efb 100644
--- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/pom.xml
+++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/pom.xml
@@ -94,6 +94,7 @@
org.testcontainers
chromadb
+ 1.21.0
test
diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfiguration.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfiguration.java
index 6f96c5aaa4f..4af9a68728b 100644
--- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfiguration.java
+++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfiguration.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -65,8 +65,11 @@ public ChromaApi chromaApi(ChromaApiProperties apiProperties,
String chromaUrl = String.format("%s:%s", connectionDetails.getHost(), connectionDetails.getPort());
- var chromaApi = new ChromaApi(chromaUrl, restClientBuilderProvider.getIfAvailable(RestClient::builder),
- objectMapper);
+ var chromaApi = ChromaApi.builder()
+ .baseUrl(chromaUrl)
+ .restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder))
+ .objectMapper(objectMapper)
+ .build();
if (StringUtils.hasText(connectionDetails.getKeyToken())) {
chromaApi.withKeyToken(connectionDetails.getKeyToken());
diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreProperties.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreProperties.java
index 3098eb4e5e2..92931fe83a0 100644
--- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreProperties.java
+++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreProperties.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
package org.springframework.ai.vectorstore.chroma.autoconfigure;
-import org.springframework.ai.chroma.vectorstore.ChromaVectorStore;
+import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants;
import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
@@ -25,13 +25,34 @@
*
* @author Christian Tzolov
* @author Soby Chacko
+ * @author Jonghoon Park
*/
@ConfigurationProperties(ChromaVectorStoreProperties.CONFIG_PREFIX)
public class ChromaVectorStoreProperties extends CommonVectorStoreProperties {
public static final String CONFIG_PREFIX = "spring.ai.vectorstore.chroma";
- private String collectionName = ChromaVectorStore.DEFAULT_COLLECTION_NAME;
+ private String tenantName = ChromaApiConstants.DEFAULT_TENANT_NAME;
+
+ private String databaseName = ChromaApiConstants.DEFAULT_DATABASE_NAME;
+
+ private String collectionName = ChromaApiConstants.DEFAULT_COLLECTION_NAME;
+
+ public String getTenantName() {
+ return tenantName;
+ }
+
+ public void setTenantName(String tenantName) {
+ this.tenantName = tenantName;
+ }
+
+ public String getDatabaseName() {
+ return databaseName;
+ }
+
+ public void setDatabaseName(String databaseName) {
+ this.databaseName = databaseName;
+ }
public String getCollectionName() {
return this.collectionName;
diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java
index d84d3e9dc66..62c2f87b51e 100644
--- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java
+++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java
@@ -58,7 +58,7 @@
public class ChromaVectorStoreAutoConfigurationIT {
@Container
- static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.20");
+ static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:1.0.0");
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations
@@ -69,7 +69,6 @@ public class ChromaVectorStoreAutoConfigurationIT {
"spring.ai.vectorstore.chroma.collectionName=TestCollection");
@Test
- @Disabled("This throws an Invalid HTTP request exception - will investigate")
public void addAndSearchWithFilters() {
this.contextRunner.withPropertyValues("spring.ai.vectorstore.chroma.initializeSchema=true").run(context -> {
diff --git a/spring-ai-spring-boot-testcontainers/pom.xml b/spring-ai-spring-boot-testcontainers/pom.xml
index 240c7a3e8b9..827b1f91dbe 100644
--- a/spring-ai-spring-boot-testcontainers/pom.xml
+++ b/spring-ai-spring-boot-testcontainers/pom.xml
@@ -295,6 +295,7 @@
org.testcontainers
chromadb
+ 1.21.0
true
diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java
index 3655f350ebc..3cb7f1a4eba 100644
--- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java
+++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java
@@ -23,7 +23,7 @@
*/
public final class ChromaImage {
- public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20");
+ public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:1.0.0");
private ChromaImage() {
diff --git a/vector-stores/spring-ai-chroma-store/pom.xml b/vector-stores/spring-ai-chroma-store/pom.xml
index 63aebf3c845..bc885767421 100644
--- a/vector-stores/spring-ai-chroma-store/pom.xml
+++ b/vector-stores/spring-ai-chroma-store/pom.xml
@@ -66,6 +66,7 @@
org.testcontainers
chromadb
+ 1.21.0
test
diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java
index a4d15249ab5..359a6da0b2e 100644
--- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java
+++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,11 +29,12 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chroma.vectorstore.ChromaApi.QueryRequest.Include;
+import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
-import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.client.HttpClientErrorException;
@@ -46,9 +47,14 @@
*
* @author Christian Tzolov
* @author Eddú Meléndez
+ * @author Jonghoon Park
*/
public class ChromaApi {
+ public static Builder builder() {
+ return new Builder();
+ }
+
// Regular expression pattern that looks for a message inside the ValueError(...).
private static final Pattern VALUE_ERROR_PATTERN = Pattern.compile("ValueError\\('([^']*)'\\)");
@@ -62,14 +68,6 @@ public class ChromaApi {
@Nullable
private String keyToken;
- public ChromaApi(String baseUrl) {
- this(baseUrl, RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()), new ObjectMapper());
- }
-
- public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder) {
- this(baseUrl, restClientBuilder, new ObjectMapper());
- }
-
public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder, ObjectMapper objectMapper) {
this.restClient = restClientBuilder.baseUrl(baseUrl)
@@ -116,10 +114,87 @@ public List toEmbeddingResponseList(@Nullable QueryResponse queryResp
}
@Nullable
- public Collection createCollection(CreateCollectionRequest createCollectionRequest) {
+ public void createTenant(String tenantName) {
+
+ this.restClient.post()
+ .uri("/api/v2/tenants")
+ .headers(this::httpHeaders)
+ .body(new CreateTenantRequest(tenantName))
+ .retrieve()
+ .toBodilessEntity();
+ }
+
+ @Nullable
+ public Tenant getTenant(String tenantName) {
+
+ try {
+ return this.restClient.get()
+ .uri("/api/v2/tenants/{tenant_name}", tenantName)
+ .headers(this::httpHeaders)
+ .retrieve()
+ .toEntity(Tenant.class)
+ .getBody();
+ }
+ catch (HttpServerErrorException | HttpClientErrorException e) {
+ String msg = this.getErrorMessage(e);
+ if (String.format("Tenant [%s] not found", tenantName).equals(msg)) {
+ return null;
+ }
+ throw new RuntimeException(msg, e);
+ }
+ }
+
+ @Nullable
+ public void createDatabase(String tenantName, String databaseName) {
+
+ this.restClient.post()
+ .uri("/api/v2/tenants/{tenant_name}/databases", tenantName)
+ .headers(this::httpHeaders)
+ .body(new CreateDatabaseRequest(databaseName))
+ .retrieve()
+ .toBodilessEntity();
+ }
+
+ @Nullable
+ public Database getDatabase(String tenantName, String databaseName) {
+
+ try {
+ return this.restClient.get()
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}", tenantName, databaseName)
+ .headers(this::httpHeaders)
+ .retrieve()
+ .toEntity(Database.class)
+ .getBody();
+ }
+ catch (HttpServerErrorException | HttpClientErrorException e) {
+ String msg = this.getErrorMessage(e);
+ if (msg.startsWith(String.format("Database [%s] not found.", databaseName))) {
+ return null;
+ }
+ throw new RuntimeException(msg, e);
+ }
+ }
+
+ /**
+ * Delete a database with the given name.
+ * @param tenantName the name of the tenant to delete.
+ * @param databaseName the name of the database to delete.
+ */
+ public void deleteDatabase(String tenantName, String databaseName) {
+
+ this.restClient.delete()
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}", tenantName, databaseName)
+ .headers(this::httpHeaders)
+ .retrieve()
+ .toBodilessEntity();
+ }
+
+ @Nullable
+ public Collection createCollection(String tenantName, String databaseName,
+ CreateCollectionRequest createCollectionRequest) {
return this.restClient.post()
- .uri("/api/v1/collections")
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections", tenantName, databaseName)
.headers(this::httpHeaders)
.body(createCollectionRequest)
.retrieve()
@@ -132,21 +207,23 @@ public Collection createCollection(CreateCollectionRequest createCollectionReque
* @param collectionName the name of the collection to delete.
*
*/
- public void deleteCollection(String collectionName) {
+ public void deleteCollection(String tenantName, String databaseName, String collectionName) {
this.restClient.delete()
- .uri("/api/v1/collections/{collection_name}", collectionName)
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_name}", tenantName,
+ databaseName, collectionName)
.headers(this::httpHeaders)
.retrieve()
.toBodilessEntity();
}
@Nullable
- public Collection getCollection(String collectionName) {
+ public Collection getCollection(String tenantName, String databaseName, String collectionName) {
try {
return this.restClient.get()
- .uri("/api/v1/collections/{collection_name}", collectionName)
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_name}",
+ tenantName, databaseName, collectionName)
.headers(this::httpHeaders)
.retrieve()
.toEntity(Collection.class)
@@ -154,7 +231,7 @@ public Collection getCollection(String collectionName) {
}
catch (HttpServerErrorException | HttpClientErrorException e) {
String msg = this.getErrorMessage(e);
- if (String.format("Collection %s does not exist.", collectionName).equals(msg)) {
+ if (String.format("Collection [%s] does not exists", collectionName).equals(msg)) {
return null;
}
throw new RuntimeException(msg, e);
@@ -162,29 +239,33 @@ public Collection getCollection(String collectionName) {
}
@Nullable
- public List listCollections() {
+ public List listCollections(String tenantName, String databaseName) {
return this.restClient.get()
- .uri("/api/v1/collections")
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections", tenantName, databaseName)
.headers(this::httpHeaders)
.retrieve()
.toEntity(CollectionList.class)
.getBody();
}
- public void upsertEmbeddings(@Nullable String collectionId, AddEmbeddingsRequest embedding) {
+ public void upsertEmbeddings(String tenantName, String databaseName, String collectionId,
+ AddEmbeddingsRequest embedding) {
this.restClient.post()
- .uri("/api/v1/collections/{collection_id}/upsert", collectionId)
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_name}/upsert",
+ tenantName, databaseName, collectionId)
.headers(this::httpHeaders)
.body(embedding)
.retrieve()
.toBodilessEntity();
}
- public int deleteEmbeddings(@Nullable String collectionId, DeleteEmbeddingsRequest deleteRequest) {
+ public int deleteEmbeddings(String tenantName, String databaseName, String collectionId,
+ DeleteEmbeddingsRequest deleteRequest) {
return this.restClient.post()
- .uri("/api/v1/collections/{collection_id}/delete", collectionId)
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_name}/delete",
+ tenantName, databaseName, collectionId)
.headers(this::httpHeaders)
.body(deleteRequest)
.retrieve()
@@ -194,10 +275,11 @@ public int deleteEmbeddings(@Nullable String collectionId, DeleteEmbeddingsReque
}
@Nullable
- public Long countEmbeddings(String collectionId) {
+ public Long countEmbeddings(String tenantName, String databaseName, String collectionId) {
return this.restClient.get()
- .uri("/api/v1/collections/{collection_id}/count", collectionId)
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_id}/count",
+ tenantName, databaseName, collectionId)
.headers(this::httpHeaders)
.retrieve()
.toEntity(Long.class)
@@ -205,10 +287,12 @@ public Long countEmbeddings(String collectionId) {
}
@Nullable
- public QueryResponse queryCollection(@Nullable String collectionId, QueryRequest queryRequest) {
+ public QueryResponse queryCollection(String tenantName, String databaseName, String collectionId,
+ QueryRequest queryRequest) {
return this.restClient.post()
- .uri("/api/v1/collections/{collection_id}/query", collectionId)
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_id}/query",
+ tenantName, databaseName, collectionId)
.headers(this::httpHeaders)
.body(queryRequest)
.retrieve()
@@ -220,10 +304,12 @@ public QueryResponse queryCollection(@Nullable String collectionId, QueryRequest
// Chroma Client API (https://docs.trychroma.com/js_reference/Client)
//
@Nullable
- public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) {
+ public GetEmbeddingResponse getEmbeddings(String tenantName, String databaseName, String collectionId,
+ GetEmbeddingsRequest getEmbeddingsRequest) {
return this.restClient.post()
- .uri("/api/v1/collections/{collection_id}/get", collectionId)
+ .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_id}/get", tenantName,
+ databaseName, collectionId)
.headers(this::httpHeaders)
.body(getEmbeddingsRequest)
.retrieve()
@@ -271,6 +357,42 @@ private String getErrorMessage(HttpStatusCodeException e) {
return "";
}
+ /**
+ * Request to create a new tenant
+ *
+ * @param name The name of the tenant to create.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record CreateTenantRequest(@JsonProperty("name") String name) {
+ }
+
+ /**
+ * Chroma tenant.
+ *
+ * @param name The name of the tenant.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record Tenant(@JsonProperty("name") String name) {
+ }
+
+ /**
+ * Request to create a new database
+ *
+ * @param name The name of the database to create.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record CreateDatabaseRequest(@JsonProperty("name") String name) {
+ }
+
+ /**
+ * Chroma database.
+ *
+ * @param name The name of the database.
+ */
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public record Database(@JsonProperty("name") String name) {
+ }
+
/**
* Chroma embedding collection.
*
@@ -487,4 +609,36 @@ private static class CollectionList extends ArrayList {
}
+ public static class Builder {
+
+ private String baseUrl = ChromaApiConstants.DEFAULT_BASE_URL;
+
+ private RestClient.Builder restClientBuilder = RestClient.builder();
+
+ private ObjectMapper objectMapper = new ObjectMapper();
+
+ public Builder baseUrl(String baseUrl) {
+ Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
+ this.baseUrl = baseUrl;
+ return this;
+ }
+
+ public Builder restClientBuilder(RestClient.Builder restClientBuilder) {
+ Assert.notNull(restClientBuilder, "restClientBuilder cannot be null");
+ this.restClientBuilder = restClientBuilder;
+ return this;
+ }
+
+ public Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "objectMapper cannot be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ public ChromaApi build() {
+ return new ChromaApi(this.baseUrl, this.restClientBuilder, objectMapper);
+ }
+
+ }
+
}
diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java
index 2e49f843c84..b56b99673a5 100644
--- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java
+++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java
@@ -30,6 +30,7 @@
import org.springframework.ai.chroma.vectorstore.ChromaApi.AddEmbeddingsRequest;
import org.springframework.ai.chroma.vectorstore.ChromaApi.DeleteEmbeddingsRequest;
import org.springframework.ai.chroma.vectorstore.ChromaApi.Embedding;
+import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
@@ -61,13 +62,16 @@
* @author Sebastien Deleuze
* @author Soby Chacko
* @author Thomas Vitale
+ * @author Jonghoon Park
*/
public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean {
- public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection";
-
private final ChromaApi chromaApi;
+ private final String tenantName;
+
+ private final String databaseName;
+
private final String collectionName;
private FilterExpressionConverter filterExpressionConverter;
@@ -90,6 +94,8 @@ protected ChromaVectorStore(Builder builder) {
super(builder);
this.chromaApi = builder.chromaApi;
+ this.tenantName = builder.tenantName;
+ this.databaseName = builder.databaseName;
this.collectionName = builder.collectionName;
this.initializeSchema = builder.initializeSchema;
this.filterExpressionConverter = builder.filterExpressionConverter;
@@ -112,11 +118,21 @@ public static Builder builder(ChromaApi chromaApi, EmbeddingModel embeddingModel
@Override
public void afterPropertiesSet() throws Exception {
if (!this.initialized) {
- var collection = this.chromaApi.getCollection(this.collectionName);
+ var collection = this.chromaApi.getCollection(this.tenantName, this.databaseName, this.collectionName);
if (collection == null) {
if (this.initializeSchema) {
- collection = this.chromaApi
- .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName));
+ var tenant = this.chromaApi.getTenant(this.tenantName);
+ if (tenant == null) {
+ this.chromaApi.createTenant(this.tenantName);
+ }
+
+ var database = this.chromaApi.getDatabase(this.tenantName, this.databaseName);
+ if (database == null) {
+ this.chromaApi.createDatabase(this.tenantName, this.databaseName);
+ }
+
+ collection = this.chromaApi.createCollection(this.tenantName, this.databaseName,
+ new ChromaApi.CreateCollectionRequest(this.collectionName));
}
else {
throw new RuntimeException("Collection " + this.collectionName
@@ -152,14 +168,15 @@ public void doAdd(@NonNull List documents) {
embeddings.add(documentEmbeddings.get(documents.indexOf(document)));
}
- this.chromaApi.upsertEmbeddings(this.collectionId,
+ this.chromaApi.upsertEmbeddings(this.tenantName, this.databaseName, this.collectionId,
new AddEmbeddingsRequest(ids, embeddings, metadatas, contents));
}
@Override
public void doDelete(List idList) {
Assert.notNull(idList, "Document id list must not be null");
- this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList));
+ this.chromaApi.deleteEmbeddings(this.tenantName, this.databaseName, this.collectionId,
+ new DeleteEmbeddingsRequest(idList));
}
@Override
@@ -175,7 +192,7 @@ protected void doDelete(Filter.Expression expression) {
logger.debug("Deleting with where clause: " + whereClause);
DeleteEmbeddingsRequest deleteRequest = new DeleteEmbeddingsRequest(null, whereClause);
- this.chromaApi.deleteEmbeddings(this.collectionId, deleteRequest);
+ this.chromaApi.deleteEmbeddings(this.tenantName, this.databaseName, this.collectionId, deleteRequest);
}
catch (Exception e) {
logger.error("Failed to delete documents by filter: {}", e.getMessage(), e);
@@ -196,7 +213,8 @@ public List doSimilaritySearch(@NonNull SearchRequest request) {
? jsonToMap(this.filterExpressionConverter.convertExpression(request.getFilterExpression())) : null;
var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where);
- var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest);
+ var queryResponse = this.chromaApi.queryCollection(this.tenantName, this.databaseName, this.collectionId,
+ queryRequest);
var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse);
List responseDocuments = new ArrayList<>();
@@ -242,11 +260,29 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
.collectionName(this.collectionName + ":" + this.collectionId);
}
+ // used by the test
+ void createCollection() {
+ var collection = this.chromaApi.createCollection(this.tenantName, this.databaseName,
+ new ChromaApi.CreateCollectionRequest(this.collectionName));
+ if (collection != null) {
+ this.collectionId = collection.id();
+ }
+ }
+
+ // used by the test
+ void deleteCollection() {
+ this.chromaApi.deleteCollection(this.tenantName, this.databaseName, this.collectionName);
+ }
+
public static class Builder extends AbstractVectorStoreBuilder {
private final ChromaApi chromaApi;
- private String collectionName = DEFAULT_COLLECTION_NAME;
+ private String tenantName = ChromaApiConstants.DEFAULT_TENANT_NAME;
+
+ private String databaseName = ChromaApiConstants.DEFAULT_DATABASE_NAME;
+
+ private String collectionName = ChromaApiConstants.DEFAULT_COLLECTION_NAME;
private boolean initializeSchema = false;
@@ -260,6 +296,30 @@ private Builder(ChromaApi chromaApi, EmbeddingModel embeddingModel) {
this.chromaApi = chromaApi;
}
+ /**
+ * Sets the tenant name.
+ * @param tenantName the name of the tenant
+ * @return the builder instance
+ * @throws IllegalArgumentException if collectionName is null or empty
+ */
+ public Builder tenantName(String tenantName) {
+ Assert.hasText(tenantName, "tenantName must not be null or empty");
+ this.tenantName = tenantName;
+ return this;
+ }
+
+ /**
+ * Sets the database name.
+ * @param databaseName the name of the database
+ * @return the builder instance
+ * @throws IllegalArgumentException if collectionName is null or empty
+ */
+ public Builder databaseName(String databaseName) {
+ Assert.hasText(databaseName, "databaseName must not be null or empty");
+ this.databaseName = databaseName;
+ return this;
+ }
+
/**
* Sets the collection name.
* @param collectionName the name of the collection
diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/common/ChromaApiConstants.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/common/ChromaApiConstants.java
new file mode 100644
index 00000000000..80e01975e42
--- /dev/null
+++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/common/ChromaApiConstants.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.chroma.vectorstore.common;
+
+/**
+ * Common value constants for Chroma api.
+ *
+ * @author Jonghoon Park
+ */
+public class ChromaApiConstants {
+
+ public static final String DEFAULT_BASE_URL = "http://localhost:8000";
+
+ public static final String DEFAULT_TENANT_NAME = "SpringAiTenant";
+
+ public static final String DEFAULT_DATABASE_NAME = "SpringAiDatabase";
+
+ public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection";
+
+}
diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaImage.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaImage.java
index a3484542682..4893958b3f6 100644
--- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaImage.java
+++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaImage.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
*/
public final class ChromaImage {
- public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20");
+ public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:1.0.0");
private ChromaImage() {
diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java
index 88ef6a80ca3..f6216e05bd5 100644
--- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java
+++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,6 +22,8 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.testcontainers.chromadb.ChromaDBContainer;
+import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy;
+import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.MountableFile;
@@ -51,6 +53,7 @@
* @author Christian Tzolov
* @author Eddú Meléndez
* @author Thomas Vitale
+ * @author Jonghoon Park
*/
@Testcontainers
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@@ -108,7 +111,11 @@ public RestClient.Builder builder() {
@Bean
public ChromaApi chromaApi(RestClient.Builder builder) {
- return new ChromaApi(chromaContainer.getEndpoint(), builder).withBasicAuthCredentials("admin", "password");
+ return ChromaApi.builder()
+ .baseUrl(chromaContainer.getEndpoint())
+ .restClientBuilder(builder)
+ .build()
+ .withBasicAuthCredentials("admin", "password");
}
@Bean
diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaApiIT.java
index 270dcd30612..5d59526342d 100644
--- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaApiIT.java
+++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaApiIT.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,7 +22,10 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants;
import org.testcontainers.chromadb.ChromaDBContainer;
+import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy;
+import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -48,6 +51,7 @@
* @author Eddú Meléndez
* @author Thomas Vitale
* @author Soby Chacko
+ * @author Jonghoon Park
*/
@SpringBootTest
@Testcontainers
@@ -56,6 +60,10 @@ public class ChromaApiIT {
@Container
static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE);
+ final String defaultTenantName = ChromaApiConstants.DEFAULT_TENANT_NAME;
+
+ final String defaultDatabaseName = ChromaApiConstants.DEFAULT_DATABASE_NAME;
+
@Autowired
ChromaApi chromaApi;
@@ -64,57 +72,74 @@ public class ChromaApiIT {
@BeforeEach
public void beforeEach() {
- this.chromaApi.listCollections().stream().forEach(c -> this.chromaApi.deleteCollection(c.name()));
+ var tenant = this.chromaApi.getTenant(defaultTenantName);
+ if (tenant == null) {
+ this.chromaApi.createTenant(defaultTenantName);
+ }
+
+ var database = this.chromaApi.getDatabase(defaultTenantName, defaultDatabaseName);
+ if (database == null) {
+ this.chromaApi.createDatabase(defaultTenantName, defaultDatabaseName);
+ }
+
+ this.chromaApi.listCollections(defaultTenantName, defaultDatabaseName)
+ .forEach(c -> this.chromaApi.deleteCollection(defaultTenantName, defaultDatabaseName, c.name()));
}
@Test
public void testClientWithMetadata() {
Map metadata = Map.of("hnsw:space", "cosine", "hnsw:M", 5);
- var newCollection = this.chromaApi
- .createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata));
+ var newCollection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName,
+ new ChromaApi.CreateCollectionRequest("TestCollection", metadata));
assertThat(newCollection).isNotNull();
assertThat(newCollection.name()).isEqualTo("TestCollection");
}
@Test
public void testClient() {
- var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
+ var newCollection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName,
+ new ChromaApi.CreateCollectionRequest("TestCollection"));
assertThat(newCollection).isNotNull();
assertThat(newCollection.name()).isEqualTo("TestCollection");
- var getCollection = this.chromaApi.getCollection("TestCollection");
+ var getCollection = this.chromaApi.getCollection(defaultTenantName, defaultDatabaseName, "TestCollection");
assertThat(getCollection).isNotNull();
assertThat(getCollection.name()).isEqualTo("TestCollection");
assertThat(getCollection.id()).isEqualTo(newCollection.id());
- List collections = this.chromaApi.listCollections();
+ List collections = this.chromaApi.listCollections(defaultTenantName, defaultDatabaseName);
assertThat(collections).hasSize(1);
assertThat(collections.get(0).id()).isEqualTo(newCollection.id());
- this.chromaApi.deleteCollection(newCollection.name());
- assertThat(this.chromaApi.listCollections()).hasSize(0);
+ this.chromaApi.deleteCollection(defaultTenantName, defaultDatabaseName, newCollection.name());
+ assertThat(this.chromaApi.listCollections(defaultTenantName, defaultDatabaseName)).hasSize(0);
}
@Test
public void testCollection() {
- var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
- assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(0);
+ var newCollection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName,
+ new ChromaApi.CreateCollectionRequest("TestCollection"));
+ assertThat(this.chromaApi.countEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id()))
+ .isEqualTo(0);
var addEmbeddingRequest = new AddEmbeddingsRequest(List.of("id1", "id2"),
List.of(new float[] { 1f, 1f, 1f }, new float[] { 2f, 2f, 2f }),
List.of(Map.of(), Map.of("key1", "value1", "key2", true, "key3", 23.4)),
List.of("Hello World", "Big World"));
- this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest);
+ this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id(),
+ addEmbeddingRequest);
var addEmbeddingRequest2 = new AddEmbeddingsRequest("id3", new float[] { 3f, 3f, 3f },
Map.of("key1", "value1", "key2", true, "key3", 23.4), "Big World");
- this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2);
+ this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id(),
+ addEmbeddingRequest2);
- assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(3);
+ assertThat(this.chromaApi.countEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id()))
+ .isEqualTo(3);
- var queryResult = this.chromaApi.queryCollection(newCollection.id(),
+ var queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, newCollection.id(),
new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where("""
{
"key2" : { "$eq": true }
@@ -124,13 +149,15 @@ public void testCollection() {
assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id2", "id3");
// Update existing embedding.
- this.chromaApi.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f },
- Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World"));
+ this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id(),
+ new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f },
+ Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World"));
- var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2")));
+ var result = this.chromaApi.getEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id(),
+ new GetEmbeddingsRequest(List.of("id2")));
assertThat(result.ids().get(0)).isEqualTo("id2");
- queryResult = this.chromaApi.queryCollection(newCollection.id(),
+ queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, newCollection.id(),
new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where("""
{
"key2" : { "$eq": true }
@@ -143,7 +170,8 @@ public void testCollection() {
@Test
public void testQueryWhere() {
- var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection"));
+ var collection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName,
+ new ChromaApi.CreateCollectionRequest("TestCollection"));
var add1 = new AddEmbeddingsRequest("id1", new float[] { 1f, 1f, 1f },
Map.of("country", "BG", "active", true, "price", 23.4, "year", 2020),
@@ -156,13 +184,14 @@ public void testQueryWhere() {
Map.of("country", "BG", "active", false, "price", 40.1, "year", 2023),
"The World is Big and Salvation Lurks Around the Corner");
- this.chromaApi.upsertEmbeddings(collection.id(), add1);
- this.chromaApi.upsertEmbeddings(collection.id(), add2);
- this.chromaApi.upsertEmbeddings(collection.id(), add3);
+ this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, collection.id(), add1);
+ this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, collection.id(), add2);
+ this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, collection.id(), add3);
- assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3);
+ assertThat(this.chromaApi.countEmbeddings(defaultTenantName, defaultDatabaseName, collection.id()))
+ .isEqualTo(3);
- var queryResult = this.chromaApi.queryCollection(collection.id(),
+ var queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, collection.id(),
new QueryRequest(new float[] { 1f, 1f, 1f }, 3));
assertThat(queryResult.ids().get(0)).hasSize(3);
@@ -173,7 +202,7 @@ public void testQueryWhere() {
assertThat(chromaEmbeddings).hasSize(3);
assertThat(chromaEmbeddings).hasSize(3);
- queryResult = this.chromaApi.queryCollection(collection.id(),
+ queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, collection.id(),
new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where("""
{
"$and" : [
@@ -185,7 +214,7 @@ public void testQueryWhere() {
assertThat(queryResult.ids().get(0)).hasSize(2);
assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id3");
- queryResult = this.chromaApi.queryCollection(collection.id(),
+ queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, collection.id(),
new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where("""
{
"$and" : [
@@ -203,7 +232,8 @@ public void testQueryWhere() {
void shouldUseExistingCollectionWhenSchemaInitializationDisabled() { // initializeSchema
// is false by
// default.
- var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("test-collection"));
+ var collection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName,
+ new ChromaApi.CreateCollectionRequest("test-collection"));
assertThat(collection).isNotNull();
assertThat(collection.name()).isEqualTo("test-collection");
@@ -224,7 +254,7 @@ void shouldCreateNewCollectionWhenSchemaInitializationEnabled() {
.initializeImmediately(true)
.build();
- var collection = this.chromaApi.getCollection("new-collection");
+ var collection = this.chromaApi.getCollection(defaultTenantName, defaultDatabaseName, "new-collection");
assertThat(collection).isNotNull();
assertThat(collection.name()).isEqualTo("new-collection");
@@ -250,7 +280,7 @@ public static class Config {
@Bean
public ChromaApi chromaApi() {
- return new ChromaApi(chromaContainer.getEndpoint());
+ return ChromaApi.builder().baseUrl(chromaContainer.getEndpoint()).build();
}
@Bean
diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java
index 9ab4bc3ac14..b3252c67341 100644
--- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java
+++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java
@@ -25,6 +25,8 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.testcontainers.chromadb.ChromaDBContainer;
+import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy;
+import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -49,6 +51,7 @@
* @author Christian Tzolov
* @author Eddú Meléndez
* @author Thomas Vitale
+ * @author Jonghoon Park
*/
@Testcontainers
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@@ -57,6 +60,11 @@ public class ChromaVectorStoreIT extends BaseVectorStoreTests {
@Container
static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE);
+ private void resetCollection(VectorStore vectorStore) {
+ ((ChromaVectorStore) vectorStore).deleteCollection();
+ ((ChromaVectorStore) vectorStore).createCollection();
+ }
+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withUserConfiguration(TestApplication.class)
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"));
@@ -83,6 +91,8 @@ public void addAndSearch() {
VectorStore vectorStore = context.getBean(VectorStore.class);
+ resetCollection(vectorStore);
+
vectorStore.add(this.documents);
List results = vectorStore
@@ -110,6 +120,8 @@ public void simpleSearch() {
VectorStore vectorStore = context.getBean(VectorStore.class);
+ resetCollection(vectorStore);
+
var document = Document.builder()
.id("simpleDoc")
.text("The sky is blue because of Rayleigh scattering.")
@@ -139,6 +151,8 @@ public void addAndSearchWithFilters() {
VectorStore vectorStore = context.getBean(VectorStore.class);
+ resetCollection(vectorStore);
+
var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
Map.of("country", "Bulgaria"));
var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
@@ -185,6 +199,8 @@ public void documentUpdateTest() {
VectorStore vectorStore = context.getBean(VectorStore.class);
+ resetCollection(vectorStore);
+
Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
Collections.singletonMap("meta1", "meta1"));
@@ -227,6 +243,8 @@ public void searchThresholdTest() {
VectorStore vectorStore = context.getBean(VectorStore.class);
+ resetCollection(vectorStore);
+
vectorStore.add(this.documents);
var request = SearchRequest.builder().query("Great").topK(5).build();
@@ -266,7 +284,7 @@ public RestClient.Builder builder() {
@Bean
public ChromaApi chromaApi(RestClient.Builder builder) {
- return new ChromaApi(chromaContainer.getEndpoint(), builder);
+ return ChromaApi.builder().baseUrl(chromaContainer.getEndpoint()).restClientBuilder(builder).build();
}
@Bean
diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java
index ef28e268473..693973617b1 100644
--- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java
+++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,6 +27,8 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.testcontainers.chromadb.ChromaDBContainer;
+import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy;
+import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -57,6 +59,7 @@
/**
* @author Christian Tzolov
* @author Thomas Vitale
+ * @author Jonghoon Park
*/
@Testcontainers
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@@ -170,7 +173,7 @@ public RestClient.Builder builder() {
@Bean
public ChromaApi chromaApi(RestClient.Builder builder) {
- return new ChromaApi(chromaContainer.getEndpoint(), builder);
+ return ChromaApi.builder().baseUrl(chromaContainer.getEndpoint()).restClientBuilder(builder).build();
}
@Bean
diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java
index 00997ae4f44..73b38911b07 100644
--- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java
+++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java
@@ -22,6 +22,8 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.testcontainers.chromadb.ChromaDBContainer;
+import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy;
+import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -143,7 +145,10 @@ public RestClient.Builder builder() {
@Bean
public ChromaApi chromaApi(RestClient.Builder builder) {
- var chromaApi = new ChromaApi(chromaContainer.getEndpoint(), builder);
+ var chromaApi = ChromaApi.builder()
+ .baseUrl(chromaContainer.getEndpoint())
+ .restClientBuilder(builder)
+ .build();
chromaApi.withKeyToken(CHROMA_SERVER_AUTH_CREDENTIALS);
return chromaApi;
}
diff --git a/vector-stores/spring-ai-chroma-store/src/test/resources/api.yaml b/vector-stores/spring-ai-chroma-store/src/test/resources/api.yaml
deleted file mode 100644
index 924c8a046d1..00000000000
--- a/vector-stores/spring-ai-chroma-store/src/test/resources/api.yaml
+++ /dev/null
@@ -1,630 +0,0 @@
-openapi: "3.0.0"
-info:
- title: FastAPI
- version: 0.1.0
-paths:
- /api/v1:
- get:
- summary: Root
- operationId: root
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema:
- additionalProperties:
- type: integer
- type: object
- title: Response Root Api V1 Get
- /api/v1/reset:
- post:
- summary: Reset
- operationId: reset
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema:
- type: boolean
- title: Response Reset Api V1 Reset Post
- /api/v1/version:
- get:
- summary: Version
- operationId: version
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema:
- type: string
- title: Response Version Api V1 Version Get
- /api/v1/heartbeat:
- get:
- summary: Heartbeat
- operationId: heartbeat
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema:
- additionalProperties:
- type: number
- type: object
- title: Response Heartbeat Api V1 Heartbeat Get
- /api/v1/raw_sql:
- post:
- summary: Raw Sql
- operationId: raw_sql
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/RawSql'
- required: true
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections:
- get:
- summary: List Collections
- operationId: list_collections
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- post:
- summary: Create Collection
- operationId: create_collection
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/CreateCollection'
- required: true
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_id}/add:
- post:
- summary: Add
- operationId: add
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Id
- name: collection_id
- in: path
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/AddEmbedding'
- required: true
- responses:
- '201':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_id}/update:
- post:
- summary: Update
- operationId: update
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Id
- name: collection_id
- in: path
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/UpdateEmbedding'
- required: true
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_id}/upsert:
- post:
- summary: Upsert
- operationId: upsert
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Id
- name: collection_id
- in: path
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/AddEmbedding'
- required: true
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_id}/get:
- post:
- summary: Get
- operationId: get
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Id
- name: collection_id
- in: path
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/GetEmbedding'
- required: true
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {} #TODO add actual GetResult Body
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_id}/delete:
- post:
- summary: Delete
- operationId: delete
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Id
- name: collection_id
- in: path
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/DeleteEmbedding'
- required: true
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_id}/count:
- get:
- summary: Count
- operationId: count
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Id
- name: collection_id
- in: path
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema:
- type: integer
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_id}/query:
- post:
- summary: Get Nearest Neighbors
- operationId: get_nearest_neighbors
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Id
- name: collection_id
- in: path
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/QueryEmbedding'
- required: true
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_name}/create_index:
- post:
- summary: Create Index
- operationId: create_index
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Name
- name: collection_name
- in: path
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema:
- type: boolean
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_name}:
- get:
- summary: Get Collection
- operationId: get_collection
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Name
- name: collection_name
- in: path
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- delete:
- summary: Delete Collection
- operationId: delete_collection
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Name
- name: collection_name
- in: path
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
- /api/v1/collections/{collection_id}:
- put:
- summary: Update Collection
- operationId: update_collection
- parameters:
- - required: true
- schema:
- type: string
- title: Collection Id
- name: collection_id
- in: path
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/UpdateCollection'
- required: true
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
-components:
- schemas:
- AddEmbedding:
- properties:
- embeddings:
- items: {}
- type: array
- title: Embeddings
- metadatas:
- items:
- type: object
- additionalProperties: true
- type: array
- title: Metadatas
- documents:
- items:
- type: string
- type: array
- title: Documents
- ids:
- items:
- type: string
- type: array
- title: Ids
- increment_index:
- type: boolean
- title: Increment Index
- default: true
- type: object
- required:
- - ids
- title: AddEmbedding
- CreateCollection:
- properties:
- name:
- type: string
- title: Name
- metadata:
- type: object
- title: Metadata
- get_or_create:
- type: boolean
- title: Get Or Create
- default: false
- type: object
- required:
- - name
- title: CreateCollection
- DeleteEmbedding:
- properties:
- ids:
- items:
- type: string
- type: array
- title: Ids
- where:
- type: object
- title: Where
- where_document:
- type: object
- title: Where Document
- type: object
- title: DeleteEmbedding
- GetEmbedding:
- properties:
- ids:
- items:
- type: string
- type: array
- title: Ids
- where:
- type: object
- title: Where
- where_document:
- type: object
- title: Where Document
- sort:
- type: string
- title: Sort
- limit:
- type: integer
- title: Limit
- offset:
- type: integer
- title: Offset
- include:
- items:
- anyOf:
- - type: string
- enum:
- - documents
- - type: string
- enum:
- - embeddings
- - type: string
- enum:
- - metadatas
- - type: string
- enum:
- - distances
- type: array
- title: Include
- default:
- - metadatas
- - documents
- type: object
- title: GetEmbedding
- HTTPValidationError:
- properties:
- detail:
- items:
- $ref: '#/components/schemas/ValidationError'
- type: array
- title: Detail
- type: object
- title: HTTPValidationError
- QueryEmbedding:
- properties:
- where:
- type: object
- title: Where
- additionalProperties: true
- default: {}
- where_document:
- type: object
- title: Where Document
- additionalProperties: true
- default: {}
- query_embeddings:
- items: {}
- type: array
- additionalProperties: true
- title: Query Embeddings
- n_results:
- type: integer
- title: N Results
- default: 10
- include:
- items:
- type: string
- enum:
- - documents
- - embeddings
- - metadatas
- - distances
- type: array
- title: Include
- default:
- - metadatas
- - documents
- - distances
- type: object
- required:
- - query_embeddings
- title: QueryEmbedding
- RawSql:
- properties:
- raw_sql:
- type: string
- title: Raw Sql
- type: object
- required:
- - raw_sql
- title: RawSql
- UpdateCollection:
- properties:
- new_name:
- type: string
- title: New Name
- new_metadata:
- type: object
- title: New Metadata
- type: object
- title: UpdateCollection
- UpdateEmbedding:
- properties:
- embeddings:
- items: {}
- type: array
- title: Embeddings
- metadatas:
- items:
- type: object
- type: array
- title: Metadatas
- documents:
- items:
- type: string
- type: array
- title: Documents
- ids:
- items:
- type: string
- type: array
- title: Ids
- increment_index:
- type: boolean
- title: Increment Index
- default: true
- type: object
- required:
- - ids
- title: UpdateEmbedding
- ValidationError:
- properties:
- loc:
- items:
- anyOf:
- - type: string
- - type: integer
- type: array
- title: Location
- msg:
- type: string
- title: Message
- type:
- type: string
- title: Error Type
- type: object
- required:
- - loc
- - msg
- - type
- title: ValidationError
\ No newline at end of file