diff --git a/modules/chromadb/src/main/java/org/testcontainers/chromadb/ChromaDBContainer.java b/modules/chromadb/src/main/java/org/testcontainers/chromadb/ChromaDBContainer.java index a1bccf3904f..af6c3df33fc 100644 --- a/modules/chromadb/src/main/java/org/testcontainers/chromadb/ChromaDBContainer.java +++ b/modules/chromadb/src/main/java/org/testcontainers/chromadb/ChromaDBContainer.java @@ -1,7 +1,9 @@ package org.testcontainers.chromadb; +import lombok.extern.slf4j.Slf4j; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.utility.ComparableVersion; import org.testcontainers.utility.DockerImageName; /** @@ -11,6 +13,7 @@ *

* Exposed ports: 8000 */ +@Slf4j public class ChromaDBContainer extends GenericContainer { private static final DockerImageName DEFAULT_DOCKER_IMAGE = DockerImageName.parse("chromadb/chroma"); @@ -22,13 +25,32 @@ public ChromaDBContainer(String dockerImageName) { } public ChromaDBContainer(DockerImageName dockerImageName) { + this(dockerImageName, isVersion2(dockerImageName.getVersionPart())); + } + + public ChromaDBContainer(DockerImageName dockerImageName, boolean isVersion2) { super(dockerImageName); + String apiPath = isVersion2 ? "/api/v2/heartbeat" : "/api/v1/heartbeat"; dockerImageName.assertCompatibleWith(DEFAULT_DOCKER_IMAGE, GHCR_DOCKER_IMAGE); withExposedPorts(8000); - waitingFor(Wait.forHttp("/api/v1/heartbeat")); + waitingFor(Wait.forHttp(apiPath)); } public String getEndpoint() { return "http://" + getHost() + ":" + getFirstMappedPort(); } + + private static boolean isVersion2(String version) { + if (version.equals("latest")) { + return true; + } + + ComparableVersion comparableVersion = new ComparableVersion(version); + if (comparableVersion.isGreaterThanOrEqualTo("1.0.0")) { + return true; + } + + log.warn("Version {} is less than 1.0.0 or not a semantic version.", version); + return false; + } } diff --git a/modules/chromadb/src/test/java/org/testcontainers/chromadb/ChromaDBContainerTest.java b/modules/chromadb/src/test/java/org/testcontainers/chromadb/ChromaDBContainerTest.java index 6cc01ac4d59..0ec6b00601c 100644 --- a/modules/chromadb/src/test/java/org/testcontainers/chromadb/ChromaDBContainerTest.java +++ b/modules/chromadb/src/test/java/org/testcontainers/chromadb/ChromaDBContainerTest.java @@ -27,4 +27,22 @@ public void test() { given().baseUri(chroma.getEndpoint()).when().get("/api/v1/databases/test").then().statusCode(200); } } + + @Test + public void testVersion2() { + try (ChromaDBContainer chroma = new ChromaDBContainer("chromadb/chroma:1.0.0")) { + chroma.start(); + + given() + .baseUri(chroma.getEndpoint()) + .when() + .body("{\"name\": \"test\"}") + .contentType(ContentType.JSON) + .post("/api/v2/tenants") + .then() + .statusCode(200); + + given().baseUri(chroma.getEndpoint()).when().get("/api/v2/tenants/test").then().statusCode(200); + } + } }