diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/pom.xml b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/pom.xml index a852dc773fe..9eee24b3efb 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/pom.xml +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/pom.xml @@ -94,6 +94,7 @@ org.testcontainers chromadb + 1.21.0 test diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfiguration.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfiguration.java index 6f96c5aaa4f..4af9a68728b 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfiguration.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,8 +65,11 @@ public ChromaApi chromaApi(ChromaApiProperties apiProperties, String chromaUrl = String.format("%s:%s", connectionDetails.getHost(), connectionDetails.getPort()); - var chromaApi = new ChromaApi(chromaUrl, restClientBuilderProvider.getIfAvailable(RestClient::builder), - objectMapper); + var chromaApi = ChromaApi.builder() + .baseUrl(chromaUrl) + .restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder)) + .objectMapper(objectMapper) + .build(); if (StringUtils.hasText(connectionDetails.getKeyToken())) { chromaApi.withKeyToken(connectionDetails.getKeyToken()); diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreProperties.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreProperties.java index 3098eb4e5e2..92931fe83a0 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreProperties.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/main/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package org.springframework.ai.vectorstore.chroma.autoconfigure; -import org.springframework.ai.chroma.vectorstore.ChromaVectorStore; +import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants; import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -25,13 +25,34 @@ * * @author Christian Tzolov * @author Soby Chacko + * @author Jonghoon Park */ @ConfigurationProperties(ChromaVectorStoreProperties.CONFIG_PREFIX) public class ChromaVectorStoreProperties extends CommonVectorStoreProperties { public static final String CONFIG_PREFIX = "spring.ai.vectorstore.chroma"; - private String collectionName = ChromaVectorStore.DEFAULT_COLLECTION_NAME; + private String tenantName = ChromaApiConstants.DEFAULT_TENANT_NAME; + + private String databaseName = ChromaApiConstants.DEFAULT_DATABASE_NAME; + + private String collectionName = ChromaApiConstants.DEFAULT_COLLECTION_NAME; + + public String getTenantName() { + return tenantName; + } + + public void setTenantName(String tenantName) { + this.tenantName = tenantName; + } + + public String getDatabaseName() { + return databaseName; + } + + public void setDatabaseName(String databaseName) { + this.databaseName = databaseName; + } public String getCollectionName() { return this.collectionName; diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java index d84d3e9dc66..62c2f87b51e 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java @@ -58,7 +58,7 @@ public class ChromaVectorStoreAutoConfigurationIT { @Container - static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.20"); + static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:1.0.0"); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations @@ -69,7 +69,6 @@ public class ChromaVectorStoreAutoConfigurationIT { "spring.ai.vectorstore.chroma.collectionName=TestCollection"); @Test - @Disabled("This throws an Invalid HTTP request exception - will investigate") public void addAndSearchWithFilters() { this.contextRunner.withPropertyValues("spring.ai.vectorstore.chroma.initializeSchema=true").run(context -> { diff --git a/spring-ai-spring-boot-testcontainers/pom.xml b/spring-ai-spring-boot-testcontainers/pom.xml index 240c7a3e8b9..827b1f91dbe 100644 --- a/spring-ai-spring-boot-testcontainers/pom.xml +++ b/spring-ai-spring-boot-testcontainers/pom.xml @@ -295,6 +295,7 @@ org.testcontainers chromadb + 1.21.0 true diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java index 3655f350ebc..3cb7f1a4eba 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java @@ -23,7 +23,7 @@ */ public final class ChromaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:1.0.0"); private ChromaImage() { diff --git a/vector-stores/spring-ai-chroma-store/pom.xml b/vector-stores/spring-ai-chroma-store/pom.xml index 63aebf3c845..bc885767421 100644 --- a/vector-stores/spring-ai-chroma-store/pom.xml +++ b/vector-stores/spring-ai-chroma-store/pom.xml @@ -66,6 +66,7 @@ org.testcontainers chromadb + 1.21.0 test diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java index a4d15249ab5..359a6da0b2e 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaApi.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,11 +29,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chroma.vectorstore.ChromaApi.QueryRequest.Include; +import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; -import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.client.HttpClientErrorException; @@ -46,9 +47,14 @@ * * @author Christian Tzolov * @author Eddú Meléndez + * @author Jonghoon Park */ public class ChromaApi { + public static Builder builder() { + return new Builder(); + } + // Regular expression pattern that looks for a message inside the ValueError(...). private static final Pattern VALUE_ERROR_PATTERN = Pattern.compile("ValueError\\('([^']*)'\\)"); @@ -62,14 +68,6 @@ public class ChromaApi { @Nullable private String keyToken; - public ChromaApi(String baseUrl) { - this(baseUrl, RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()), new ObjectMapper()); - } - - public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder) { - this(baseUrl, restClientBuilder, new ObjectMapper()); - } - public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder, ObjectMapper objectMapper) { this.restClient = restClientBuilder.baseUrl(baseUrl) @@ -116,10 +114,87 @@ public List toEmbeddingResponseList(@Nullable QueryResponse queryResp } @Nullable - public Collection createCollection(CreateCollectionRequest createCollectionRequest) { + public void createTenant(String tenantName) { + + this.restClient.post() + .uri("/api/v2/tenants") + .headers(this::httpHeaders) + .body(new CreateTenantRequest(tenantName)) + .retrieve() + .toBodilessEntity(); + } + + @Nullable + public Tenant getTenant(String tenantName) { + + try { + return this.restClient.get() + .uri("/api/v2/tenants/{tenant_name}", tenantName) + .headers(this::httpHeaders) + .retrieve() + .toEntity(Tenant.class) + .getBody(); + } + catch (HttpServerErrorException | HttpClientErrorException e) { + String msg = this.getErrorMessage(e); + if (String.format("Tenant [%s] not found", tenantName).equals(msg)) { + return null; + } + throw new RuntimeException(msg, e); + } + } + + @Nullable + public void createDatabase(String tenantName, String databaseName) { + + this.restClient.post() + .uri("/api/v2/tenants/{tenant_name}/databases", tenantName) + .headers(this::httpHeaders) + .body(new CreateDatabaseRequest(databaseName)) + .retrieve() + .toBodilessEntity(); + } + + @Nullable + public Database getDatabase(String tenantName, String databaseName) { + + try { + return this.restClient.get() + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}", tenantName, databaseName) + .headers(this::httpHeaders) + .retrieve() + .toEntity(Database.class) + .getBody(); + } + catch (HttpServerErrorException | HttpClientErrorException e) { + String msg = this.getErrorMessage(e); + if (msg.startsWith(String.format("Database [%s] not found.", databaseName))) { + return null; + } + throw new RuntimeException(msg, e); + } + } + + /** + * Delete a database with the given name. + * @param tenantName the name of the tenant to delete. + * @param databaseName the name of the database to delete. + */ + public void deleteDatabase(String tenantName, String databaseName) { + + this.restClient.delete() + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}", tenantName, databaseName) + .headers(this::httpHeaders) + .retrieve() + .toBodilessEntity(); + } + + @Nullable + public Collection createCollection(String tenantName, String databaseName, + CreateCollectionRequest createCollectionRequest) { return this.restClient.post() - .uri("/api/v1/collections") + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections", tenantName, databaseName) .headers(this::httpHeaders) .body(createCollectionRequest) .retrieve() @@ -132,21 +207,23 @@ public Collection createCollection(CreateCollectionRequest createCollectionReque * @param collectionName the name of the collection to delete. * */ - public void deleteCollection(String collectionName) { + public void deleteCollection(String tenantName, String databaseName, String collectionName) { this.restClient.delete() - .uri("/api/v1/collections/{collection_name}", collectionName) + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_name}", tenantName, + databaseName, collectionName) .headers(this::httpHeaders) .retrieve() .toBodilessEntity(); } @Nullable - public Collection getCollection(String collectionName) { + public Collection getCollection(String tenantName, String databaseName, String collectionName) { try { return this.restClient.get() - .uri("/api/v1/collections/{collection_name}", collectionName) + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_name}", + tenantName, databaseName, collectionName) .headers(this::httpHeaders) .retrieve() .toEntity(Collection.class) @@ -154,7 +231,7 @@ public Collection getCollection(String collectionName) { } catch (HttpServerErrorException | HttpClientErrorException e) { String msg = this.getErrorMessage(e); - if (String.format("Collection %s does not exist.", collectionName).equals(msg)) { + if (String.format("Collection [%s] does not exists", collectionName).equals(msg)) { return null; } throw new RuntimeException(msg, e); @@ -162,29 +239,33 @@ public Collection getCollection(String collectionName) { } @Nullable - public List listCollections() { + public List listCollections(String tenantName, String databaseName) { return this.restClient.get() - .uri("/api/v1/collections") + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections", tenantName, databaseName) .headers(this::httpHeaders) .retrieve() .toEntity(CollectionList.class) .getBody(); } - public void upsertEmbeddings(@Nullable String collectionId, AddEmbeddingsRequest embedding) { + public void upsertEmbeddings(String tenantName, String databaseName, String collectionId, + AddEmbeddingsRequest embedding) { this.restClient.post() - .uri("/api/v1/collections/{collection_id}/upsert", collectionId) + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_name}/upsert", + tenantName, databaseName, collectionId) .headers(this::httpHeaders) .body(embedding) .retrieve() .toBodilessEntity(); } - public int deleteEmbeddings(@Nullable String collectionId, DeleteEmbeddingsRequest deleteRequest) { + public int deleteEmbeddings(String tenantName, String databaseName, String collectionId, + DeleteEmbeddingsRequest deleteRequest) { return this.restClient.post() - .uri("/api/v1/collections/{collection_id}/delete", collectionId) + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_name}/delete", + tenantName, databaseName, collectionId) .headers(this::httpHeaders) .body(deleteRequest) .retrieve() @@ -194,10 +275,11 @@ public int deleteEmbeddings(@Nullable String collectionId, DeleteEmbeddingsReque } @Nullable - public Long countEmbeddings(String collectionId) { + public Long countEmbeddings(String tenantName, String databaseName, String collectionId) { return this.restClient.get() - .uri("/api/v1/collections/{collection_id}/count", collectionId) + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_id}/count", + tenantName, databaseName, collectionId) .headers(this::httpHeaders) .retrieve() .toEntity(Long.class) @@ -205,10 +287,12 @@ public Long countEmbeddings(String collectionId) { } @Nullable - public QueryResponse queryCollection(@Nullable String collectionId, QueryRequest queryRequest) { + public QueryResponse queryCollection(String tenantName, String databaseName, String collectionId, + QueryRequest queryRequest) { return this.restClient.post() - .uri("/api/v1/collections/{collection_id}/query", collectionId) + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_id}/query", + tenantName, databaseName, collectionId) .headers(this::httpHeaders) .body(queryRequest) .retrieve() @@ -220,10 +304,12 @@ public QueryResponse queryCollection(@Nullable String collectionId, QueryRequest // Chroma Client API (https://docs.trychroma.com/js_reference/Client) // @Nullable - public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) { + public GetEmbeddingResponse getEmbeddings(String tenantName, String databaseName, String collectionId, + GetEmbeddingsRequest getEmbeddingsRequest) { return this.restClient.post() - .uri("/api/v1/collections/{collection_id}/get", collectionId) + .uri("/api/v2/tenants/{tenant_name}/databases/{database_name}/collections/{collection_id}/get", tenantName, + databaseName, collectionId) .headers(this::httpHeaders) .body(getEmbeddingsRequest) .retrieve() @@ -271,6 +357,42 @@ private String getErrorMessage(HttpStatusCodeException e) { return ""; } + /** + * Request to create a new tenant + * + * @param name The name of the tenant to create. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record CreateTenantRequest(@JsonProperty("name") String name) { + } + + /** + * Chroma tenant. + * + * @param name The name of the tenant. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Tenant(@JsonProperty("name") String name) { + } + + /** + * Request to create a new database + * + * @param name The name of the database to create. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record CreateDatabaseRequest(@JsonProperty("name") String name) { + } + + /** + * Chroma database. + * + * @param name The name of the database. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Database(@JsonProperty("name") String name) { + } + /** * Chroma embedding collection. * @@ -487,4 +609,36 @@ private static class CollectionList extends ArrayList { } + public static class Builder { + + private String baseUrl = ChromaApiConstants.DEFAULT_BASE_URL; + + private RestClient.Builder restClientBuilder = RestClient.builder(); + + private ObjectMapper objectMapper = new ObjectMapper(); + + public Builder baseUrl(String baseUrl) { + Assert.hasText(baseUrl, "baseUrl cannot be null or empty"); + this.baseUrl = baseUrl; + return this; + } + + public Builder restClientBuilder(RestClient.Builder restClientBuilder) { + Assert.notNull(restClientBuilder, "restClientBuilder cannot be null"); + this.restClientBuilder = restClientBuilder; + return this; + } + + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.objectMapper = objectMapper; + return this; + } + + public ChromaApi build() { + return new ChromaApi(this.baseUrl, this.restClientBuilder, objectMapper); + } + + } + } diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java index 2e49f843c84..b56b99673a5 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java @@ -30,6 +30,7 @@ import org.springframework.ai.chroma.vectorstore.ChromaApi.AddEmbeddingsRequest; import org.springframework.ai.chroma.vectorstore.ChromaApi.DeleteEmbeddingsRequest; import org.springframework.ai.chroma.vectorstore.ChromaApi.Embedding; +import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; @@ -61,13 +62,16 @@ * @author Sebastien Deleuze * @author Soby Chacko * @author Thomas Vitale + * @author Jonghoon Park */ public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean { - public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection"; - private final ChromaApi chromaApi; + private final String tenantName; + + private final String databaseName; + private final String collectionName; private FilterExpressionConverter filterExpressionConverter; @@ -90,6 +94,8 @@ protected ChromaVectorStore(Builder builder) { super(builder); this.chromaApi = builder.chromaApi; + this.tenantName = builder.tenantName; + this.databaseName = builder.databaseName; this.collectionName = builder.collectionName; this.initializeSchema = builder.initializeSchema; this.filterExpressionConverter = builder.filterExpressionConverter; @@ -112,11 +118,21 @@ public static Builder builder(ChromaApi chromaApi, EmbeddingModel embeddingModel @Override public void afterPropertiesSet() throws Exception { if (!this.initialized) { - var collection = this.chromaApi.getCollection(this.collectionName); + var collection = this.chromaApi.getCollection(this.tenantName, this.databaseName, this.collectionName); if (collection == null) { if (this.initializeSchema) { - collection = this.chromaApi - .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName)); + var tenant = this.chromaApi.getTenant(this.tenantName); + if (tenant == null) { + this.chromaApi.createTenant(this.tenantName); + } + + var database = this.chromaApi.getDatabase(this.tenantName, this.databaseName); + if (database == null) { + this.chromaApi.createDatabase(this.tenantName, this.databaseName); + } + + collection = this.chromaApi.createCollection(this.tenantName, this.databaseName, + new ChromaApi.CreateCollectionRequest(this.collectionName)); } else { throw new RuntimeException("Collection " + this.collectionName @@ -152,14 +168,15 @@ public void doAdd(@NonNull List documents) { embeddings.add(documentEmbeddings.get(documents.indexOf(document))); } - this.chromaApi.upsertEmbeddings(this.collectionId, + this.chromaApi.upsertEmbeddings(this.tenantName, this.databaseName, this.collectionId, new AddEmbeddingsRequest(ids, embeddings, metadatas, contents)); } @Override public void doDelete(List idList) { Assert.notNull(idList, "Document id list must not be null"); - this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList)); + this.chromaApi.deleteEmbeddings(this.tenantName, this.databaseName, this.collectionId, + new DeleteEmbeddingsRequest(idList)); } @Override @@ -175,7 +192,7 @@ protected void doDelete(Filter.Expression expression) { logger.debug("Deleting with where clause: " + whereClause); DeleteEmbeddingsRequest deleteRequest = new DeleteEmbeddingsRequest(null, whereClause); - this.chromaApi.deleteEmbeddings(this.collectionId, deleteRequest); + this.chromaApi.deleteEmbeddings(this.tenantName, this.databaseName, this.collectionId, deleteRequest); } catch (Exception e) { logger.error("Failed to delete documents by filter: {}", e.getMessage(), e); @@ -196,7 +213,8 @@ public List doSimilaritySearch(@NonNull SearchRequest request) { ? jsonToMap(this.filterExpressionConverter.convertExpression(request.getFilterExpression())) : null; var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where); - var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest); + var queryResponse = this.chromaApi.queryCollection(this.tenantName, this.databaseName, this.collectionId, + queryRequest); var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse); List responseDocuments = new ArrayList<>(); @@ -242,11 +260,29 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str .collectionName(this.collectionName + ":" + this.collectionId); } + // used by the test + void createCollection() { + var collection = this.chromaApi.createCollection(this.tenantName, this.databaseName, + new ChromaApi.CreateCollectionRequest(this.collectionName)); + if (collection != null) { + this.collectionId = collection.id(); + } + } + + // used by the test + void deleteCollection() { + this.chromaApi.deleteCollection(this.tenantName, this.databaseName, this.collectionName); + } + public static class Builder extends AbstractVectorStoreBuilder { private final ChromaApi chromaApi; - private String collectionName = DEFAULT_COLLECTION_NAME; + private String tenantName = ChromaApiConstants.DEFAULT_TENANT_NAME; + + private String databaseName = ChromaApiConstants.DEFAULT_DATABASE_NAME; + + private String collectionName = ChromaApiConstants.DEFAULT_COLLECTION_NAME; private boolean initializeSchema = false; @@ -260,6 +296,30 @@ private Builder(ChromaApi chromaApi, EmbeddingModel embeddingModel) { this.chromaApi = chromaApi; } + /** + * Sets the tenant name. + * @param tenantName the name of the tenant + * @return the builder instance + * @throws IllegalArgumentException if collectionName is null or empty + */ + public Builder tenantName(String tenantName) { + Assert.hasText(tenantName, "tenantName must not be null or empty"); + this.tenantName = tenantName; + return this; + } + + /** + * Sets the database name. + * @param databaseName the name of the database + * @return the builder instance + * @throws IllegalArgumentException if collectionName is null or empty + */ + public Builder databaseName(String databaseName) { + Assert.hasText(databaseName, "databaseName must not be null or empty"); + this.databaseName = databaseName; + return this; + } + /** * Sets the collection name. * @param collectionName the name of the collection diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/common/ChromaApiConstants.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/common/ChromaApiConstants.java new file mode 100644 index 00000000000..80e01975e42 --- /dev/null +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/common/ChromaApiConstants.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chroma.vectorstore.common; + +/** + * Common value constants for Chroma api. + * + * @author Jonghoon Park + */ +public class ChromaApiConstants { + + public static final String DEFAULT_BASE_URL = "http://localhost:8000"; + + public static final String DEFAULT_TENANT_NAME = "SpringAiTenant"; + + public static final String DEFAULT_DATABASE_NAME = "SpringAiDatabase"; + + public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection"; + +} diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaImage.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaImage.java index a3484542682..4893958b3f6 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaImage.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/ChromaImage.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ */ public final class ChromaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:1.0.0"); private ChromaImage() { diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java index 88ef6a80ca3..f6216e05bd5 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy; +import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.utility.MountableFile; @@ -51,6 +53,7 @@ * @author Christian Tzolov * @author Eddú Meléndez * @author Thomas Vitale + * @author Jonghoon Park */ @Testcontainers @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -108,7 +111,11 @@ public RestClient.Builder builder() { @Bean public ChromaApi chromaApi(RestClient.Builder builder) { - return new ChromaApi(chromaContainer.getEndpoint(), builder).withBasicAuthCredentials("admin", "password"); + return ChromaApi.builder() + .baseUrl(chromaContainer.getEndpoint()) + .restClientBuilder(builder) + .build() + .withBasicAuthCredentials("admin", "password"); } @Bean diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaApiIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaApiIT.java index 270dcd30612..5d59526342d 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaApiIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants; import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy; +import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -48,6 +51,7 @@ * @author Eddú Meléndez * @author Thomas Vitale * @author Soby Chacko + * @author Jonghoon Park */ @SpringBootTest @Testcontainers @@ -56,6 +60,10 @@ public class ChromaApiIT { @Container static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE); + final String defaultTenantName = ChromaApiConstants.DEFAULT_TENANT_NAME; + + final String defaultDatabaseName = ChromaApiConstants.DEFAULT_DATABASE_NAME; + @Autowired ChromaApi chromaApi; @@ -64,57 +72,74 @@ public class ChromaApiIT { @BeforeEach public void beforeEach() { - this.chromaApi.listCollections().stream().forEach(c -> this.chromaApi.deleteCollection(c.name())); + var tenant = this.chromaApi.getTenant(defaultTenantName); + if (tenant == null) { + this.chromaApi.createTenant(defaultTenantName); + } + + var database = this.chromaApi.getDatabase(defaultTenantName, defaultDatabaseName); + if (database == null) { + this.chromaApi.createDatabase(defaultTenantName, defaultDatabaseName); + } + + this.chromaApi.listCollections(defaultTenantName, defaultDatabaseName) + .forEach(c -> this.chromaApi.deleteCollection(defaultTenantName, defaultDatabaseName, c.name())); } @Test public void testClientWithMetadata() { Map metadata = Map.of("hnsw:space", "cosine", "hnsw:M", 5); - var newCollection = this.chromaApi - .createCollection(new ChromaApi.CreateCollectionRequest("TestCollection", metadata)); + var newCollection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName, + new ChromaApi.CreateCollectionRequest("TestCollection", metadata)); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); } @Test public void testClient() { - var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var newCollection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName, + new ChromaApi.CreateCollectionRequest("TestCollection")); assertThat(newCollection).isNotNull(); assertThat(newCollection.name()).isEqualTo("TestCollection"); - var getCollection = this.chromaApi.getCollection("TestCollection"); + var getCollection = this.chromaApi.getCollection(defaultTenantName, defaultDatabaseName, "TestCollection"); assertThat(getCollection).isNotNull(); assertThat(getCollection.name()).isEqualTo("TestCollection"); assertThat(getCollection.id()).isEqualTo(newCollection.id()); - List collections = this.chromaApi.listCollections(); + List collections = this.chromaApi.listCollections(defaultTenantName, defaultDatabaseName); assertThat(collections).hasSize(1); assertThat(collections.get(0).id()).isEqualTo(newCollection.id()); - this.chromaApi.deleteCollection(newCollection.name()); - assertThat(this.chromaApi.listCollections()).hasSize(0); + this.chromaApi.deleteCollection(defaultTenantName, defaultDatabaseName, newCollection.name()); + assertThat(this.chromaApi.listCollections(defaultTenantName, defaultDatabaseName)).hasSize(0); } @Test public void testCollection() { - var newCollection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); - assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(0); + var newCollection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName, + new ChromaApi.CreateCollectionRequest("TestCollection")); + assertThat(this.chromaApi.countEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id())) + .isEqualTo(0); var addEmbeddingRequest = new AddEmbeddingsRequest(List.of("id1", "id2"), List.of(new float[] { 1f, 1f, 1f }, new float[] { 2f, 2f, 2f }), List.of(Map.of(), Map.of("key1", "value1", "key2", true, "key3", 23.4)), List.of("Hello World", "Big World")); - this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest); + this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id(), + addEmbeddingRequest); var addEmbeddingRequest2 = new AddEmbeddingsRequest("id3", new float[] { 3f, 3f, 3f }, Map.of("key1", "value1", "key2", true, "key3", 23.4), "Big World"); - this.chromaApi.upsertEmbeddings(newCollection.id(), addEmbeddingRequest2); + this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id(), + addEmbeddingRequest2); - assertThat(this.chromaApi.countEmbeddings(newCollection.id())).isEqualTo(3); + assertThat(this.chromaApi.countEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id())) + .isEqualTo(3); - var queryResult = this.chromaApi.queryCollection(newCollection.id(), + var queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, newCollection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "key2" : { "$eq": true } @@ -124,13 +149,15 @@ public void testCollection() { assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id2", "id3"); // Update existing embedding. - this.chromaApi.upsertEmbeddings(newCollection.id(), new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, - Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World")); + this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id(), + new AddEmbeddingsRequest("id3", new float[] { 6f, 6f, 6f }, + Map.of("key1", "value2", "key2", false, "key4", 23.4), "Small World")); - var result = this.chromaApi.getEmbeddings(newCollection.id(), new GetEmbeddingsRequest(List.of("id2"))); + var result = this.chromaApi.getEmbeddings(defaultTenantName, defaultDatabaseName, newCollection.id(), + new GetEmbeddingsRequest(List.of("id2"))); assertThat(result.ids().get(0)).isEqualTo("id2"); - queryResult = this.chromaApi.queryCollection(newCollection.id(), + queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, newCollection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "key2" : { "$eq": true } @@ -143,7 +170,8 @@ public void testCollection() { @Test public void testQueryWhere() { - var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("TestCollection")); + var collection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName, + new ChromaApi.CreateCollectionRequest("TestCollection")); var add1 = new AddEmbeddingsRequest("id1", new float[] { 1f, 1f, 1f }, Map.of("country", "BG", "active", true, "price", 23.4, "year", 2020), @@ -156,13 +184,14 @@ public void testQueryWhere() { Map.of("country", "BG", "active", false, "price", 40.1, "year", 2023), "The World is Big and Salvation Lurks Around the Corner"); - this.chromaApi.upsertEmbeddings(collection.id(), add1); - this.chromaApi.upsertEmbeddings(collection.id(), add2); - this.chromaApi.upsertEmbeddings(collection.id(), add3); + this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, collection.id(), add1); + this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, collection.id(), add2); + this.chromaApi.upsertEmbeddings(defaultTenantName, defaultDatabaseName, collection.id(), add3); - assertThat(this.chromaApi.countEmbeddings(collection.id())).isEqualTo(3); + assertThat(this.chromaApi.countEmbeddings(defaultTenantName, defaultDatabaseName, collection.id())) + .isEqualTo(3); - var queryResult = this.chromaApi.queryCollection(collection.id(), + var queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3)); assertThat(queryResult.ids().get(0)).hasSize(3); @@ -173,7 +202,7 @@ public void testQueryWhere() { assertThat(chromaEmbeddings).hasSize(3); assertThat(chromaEmbeddings).hasSize(3); - queryResult = this.chromaApi.queryCollection(collection.id(), + queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "$and" : [ @@ -185,7 +214,7 @@ public void testQueryWhere() { assertThat(queryResult.ids().get(0)).hasSize(2); assertThat(queryResult.ids().get(0)).containsExactlyInAnyOrder("id1", "id3"); - queryResult = this.chromaApi.queryCollection(collection.id(), + queryResult = this.chromaApi.queryCollection(defaultTenantName, defaultDatabaseName, collection.id(), new QueryRequest(new float[] { 1f, 1f, 1f }, 3, this.chromaApi.where(""" { "$and" : [ @@ -203,7 +232,8 @@ public void testQueryWhere() { void shouldUseExistingCollectionWhenSchemaInitializationDisabled() { // initializeSchema // is false by // default. - var collection = this.chromaApi.createCollection(new ChromaApi.CreateCollectionRequest("test-collection")); + var collection = this.chromaApi.createCollection(defaultTenantName, defaultDatabaseName, + new ChromaApi.CreateCollectionRequest("test-collection")); assertThat(collection).isNotNull(); assertThat(collection.name()).isEqualTo("test-collection"); @@ -224,7 +254,7 @@ void shouldCreateNewCollectionWhenSchemaInitializationEnabled() { .initializeImmediately(true) .build(); - var collection = this.chromaApi.getCollection("new-collection"); + var collection = this.chromaApi.getCollection(defaultTenantName, defaultDatabaseName, "new-collection"); assertThat(collection).isNotNull(); assertThat(collection.name()).isEqualTo("new-collection"); @@ -250,7 +280,7 @@ public static class Config { @Bean public ChromaApi chromaApi() { - return new ChromaApi(chromaContainer.getEndpoint()); + return ChromaApi.builder().baseUrl(chromaContainer.getEndpoint()).build(); } @Bean diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java index 9ab4bc3ac14..b3252c67341 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java @@ -25,6 +25,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy; +import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -49,6 +51,7 @@ * @author Christian Tzolov * @author Eddú Meléndez * @author Thomas Vitale + * @author Jonghoon Park */ @Testcontainers @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -57,6 +60,11 @@ public class ChromaVectorStoreIT extends BaseVectorStoreTests { @Container static ChromaDBContainer chromaContainer = new ChromaDBContainer(ChromaImage.DEFAULT_IMAGE); + private void resetCollection(VectorStore vectorStore) { + ((ChromaVectorStore) vectorStore).deleteCollection(); + ((ChromaVectorStore) vectorStore).createCollection(); + } + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(TestApplication.class) .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")); @@ -83,6 +91,8 @@ public void addAndSearch() { VectorStore vectorStore = context.getBean(VectorStore.class); + resetCollection(vectorStore); + vectorStore.add(this.documents); List results = vectorStore @@ -110,6 +120,8 @@ public void simpleSearch() { VectorStore vectorStore = context.getBean(VectorStore.class); + resetCollection(vectorStore); + var document = Document.builder() .id("simpleDoc") .text("The sky is blue because of Rayleigh scattering.") @@ -139,6 +151,8 @@ public void addAndSearchWithFilters() { VectorStore vectorStore = context.getBean(VectorStore.class); + resetCollection(vectorStore); + var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", Map.of("country", "Bulgaria")); var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", @@ -185,6 +199,8 @@ public void documentUpdateTest() { VectorStore vectorStore = context.getBean(VectorStore.class); + resetCollection(vectorStore); + Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", Collections.singletonMap("meta1", "meta1")); @@ -227,6 +243,8 @@ public void searchThresholdTest() { VectorStore vectorStore = context.getBean(VectorStore.class); + resetCollection(vectorStore); + vectorStore.add(this.documents); var request = SearchRequest.builder().query("Great").topK(5).build(); @@ -266,7 +284,7 @@ public RestClient.Builder builder() { @Bean public ChromaApi chromaApi(RestClient.Builder builder) { - return new ChromaApi(chromaContainer.getEndpoint(), builder); + return ChromaApi.builder().baseUrl(chromaContainer.getEndpoint()).restClientBuilder(builder).build(); } @Bean diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java index ef28e268473..693973617b1 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy; +import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -57,6 +59,7 @@ /** * @author Christian Tzolov * @author Thomas Vitale + * @author Jonghoon Park */ @Testcontainers @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -170,7 +173,7 @@ public RestClient.Builder builder() { @Bean public ChromaApi chromaApi(RestClient.Builder builder) { - return new ChromaApi(chromaContainer.getEndpoint(), builder); + return ChromaApi.builder().baseUrl(chromaContainer.getEndpoint()).restClientBuilder(builder).build(); } @Bean diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java index 00997ae4f44..73b38911b07 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java @@ -22,6 +22,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.testcontainers.chromadb.ChromaDBContainer; +import org.testcontainers.containers.wait.strategy.AbstractWaitStrategy; +import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -143,7 +145,10 @@ public RestClient.Builder builder() { @Bean public ChromaApi chromaApi(RestClient.Builder builder) { - var chromaApi = new ChromaApi(chromaContainer.getEndpoint(), builder); + var chromaApi = ChromaApi.builder() + .baseUrl(chromaContainer.getEndpoint()) + .restClientBuilder(builder) + .build(); chromaApi.withKeyToken(CHROMA_SERVER_AUTH_CREDENTIALS); return chromaApi; } diff --git a/vector-stores/spring-ai-chroma-store/src/test/resources/api.yaml b/vector-stores/spring-ai-chroma-store/src/test/resources/api.yaml deleted file mode 100644 index 924c8a046d1..00000000000 --- a/vector-stores/spring-ai-chroma-store/src/test/resources/api.yaml +++ /dev/null @@ -1,630 +0,0 @@ -openapi: "3.0.0" -info: - title: FastAPI - version: 0.1.0 -paths: - /api/v1: - get: - summary: Root - operationId: root - responses: - '200': - description: Successful Response - content: - application/json: - schema: - additionalProperties: - type: integer - type: object - title: Response Root Api V1 Get - /api/v1/reset: - post: - summary: Reset - operationId: reset - responses: - '200': - description: Successful Response - content: - application/json: - schema: - type: boolean - title: Response Reset Api V1 Reset Post - /api/v1/version: - get: - summary: Version - operationId: version - responses: - '200': - description: Successful Response - content: - application/json: - schema: - type: string - title: Response Version Api V1 Version Get - /api/v1/heartbeat: - get: - summary: Heartbeat - operationId: heartbeat - responses: - '200': - description: Successful Response - content: - application/json: - schema: - additionalProperties: - type: number - type: object - title: Response Heartbeat Api V1 Heartbeat Get - /api/v1/raw_sql: - post: - summary: Raw Sql - operationId: raw_sql - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RawSql' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections: - get: - summary: List Collections - operationId: list_collections - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - post: - summary: Create Collection - operationId: create_collection - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/CreateCollection' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_id}/add: - post: - summary: Add - operationId: add - parameters: - - required: true - schema: - type: string - title: Collection Id - name: collection_id - in: path - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/AddEmbedding' - required: true - responses: - '201': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_id}/update: - post: - summary: Update - operationId: update - parameters: - - required: true - schema: - type: string - title: Collection Id - name: collection_id - in: path - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/UpdateEmbedding' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_id}/upsert: - post: - summary: Upsert - operationId: upsert - parameters: - - required: true - schema: - type: string - title: Collection Id - name: collection_id - in: path - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/AddEmbedding' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_id}/get: - post: - summary: Get - operationId: get - parameters: - - required: true - schema: - type: string - title: Collection Id - name: collection_id - in: path - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/GetEmbedding' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} #TODO add actual GetResult Body - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_id}/delete: - post: - summary: Delete - operationId: delete - parameters: - - required: true - schema: - type: string - title: Collection Id - name: collection_id - in: path - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/DeleteEmbedding' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_id}/count: - get: - summary: Count - operationId: count - parameters: - - required: true - schema: - type: string - title: Collection Id - name: collection_id - in: path - responses: - '200': - description: Successful Response - content: - application/json: - schema: - type: integer - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_id}/query: - post: - summary: Get Nearest Neighbors - operationId: get_nearest_neighbors - parameters: - - required: true - schema: - type: string - title: Collection Id - name: collection_id - in: path - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/QueryEmbedding' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_name}/create_index: - post: - summary: Create Index - operationId: create_index - parameters: - - required: true - schema: - type: string - title: Collection Name - name: collection_name - in: path - responses: - '200': - description: Successful Response - content: - application/json: - schema: - type: boolean - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_name}: - get: - summary: Get Collection - operationId: get_collection - parameters: - - required: true - schema: - type: string - title: Collection Name - name: collection_name - in: path - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - delete: - summary: Delete Collection - operationId: delete_collection - parameters: - - required: true - schema: - type: string - title: Collection Name - name: collection_name - in: path - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - /api/v1/collections/{collection_id}: - put: - summary: Update Collection - operationId: update_collection - parameters: - - required: true - schema: - type: string - title: Collection Id - name: collection_id - in: path - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/UpdateCollection' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: {} - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' -components: - schemas: - AddEmbedding: - properties: - embeddings: - items: {} - type: array - title: Embeddings - metadatas: - items: - type: object - additionalProperties: true - type: array - title: Metadatas - documents: - items: - type: string - type: array - title: Documents - ids: - items: - type: string - type: array - title: Ids - increment_index: - type: boolean - title: Increment Index - default: true - type: object - required: - - ids - title: AddEmbedding - CreateCollection: - properties: - name: - type: string - title: Name - metadata: - type: object - title: Metadata - get_or_create: - type: boolean - title: Get Or Create - default: false - type: object - required: - - name - title: CreateCollection - DeleteEmbedding: - properties: - ids: - items: - type: string - type: array - title: Ids - where: - type: object - title: Where - where_document: - type: object - title: Where Document - type: object - title: DeleteEmbedding - GetEmbedding: - properties: - ids: - items: - type: string - type: array - title: Ids - where: - type: object - title: Where - where_document: - type: object - title: Where Document - sort: - type: string - title: Sort - limit: - type: integer - title: Limit - offset: - type: integer - title: Offset - include: - items: - anyOf: - - type: string - enum: - - documents - - type: string - enum: - - embeddings - - type: string - enum: - - metadatas - - type: string - enum: - - distances - type: array - title: Include - default: - - metadatas - - documents - type: object - title: GetEmbedding - HTTPValidationError: - properties: - detail: - items: - $ref: '#/components/schemas/ValidationError' - type: array - title: Detail - type: object - title: HTTPValidationError - QueryEmbedding: - properties: - where: - type: object - title: Where - additionalProperties: true - default: {} - where_document: - type: object - title: Where Document - additionalProperties: true - default: {} - query_embeddings: - items: {} - type: array - additionalProperties: true - title: Query Embeddings - n_results: - type: integer - title: N Results - default: 10 - include: - items: - type: string - enum: - - documents - - embeddings - - metadatas - - distances - type: array - title: Include - default: - - metadatas - - documents - - distances - type: object - required: - - query_embeddings - title: QueryEmbedding - RawSql: - properties: - raw_sql: - type: string - title: Raw Sql - type: object - required: - - raw_sql - title: RawSql - UpdateCollection: - properties: - new_name: - type: string - title: New Name - new_metadata: - type: object - title: New Metadata - type: object - title: UpdateCollection - UpdateEmbedding: - properties: - embeddings: - items: {} - type: array - title: Embeddings - metadatas: - items: - type: object - type: array - title: Metadatas - documents: - items: - type: string - type: array - title: Documents - ids: - items: - type: string - type: array - title: Ids - increment_index: - type: boolean - title: Increment Index - default: true - type: object - required: - - ids - title: UpdateEmbedding - ValidationError: - properties: - loc: - items: - anyOf: - - type: string - - type: integer - type: array - title: Location - msg: - type: string - title: Message - type: - type: string - title: Error Type - type: object - required: - - loc - - msg - - type - title: ValidationError \ No newline at end of file