Skip to content

Commit a2197e1

Browse files
authored
Add support to pull model for DockerModelRunnerContainer (#10253)
1 parent 77154e1 commit a2197e1

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

core/src/main/java/org/testcontainers/containers/DockerModelRunnerContainer.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
package org.testcontainers.containers;
22

3+
import com.github.dockerjava.api.command.InspectContainerResponse;
4+
import lombok.extern.slf4j.Slf4j;
35
import org.testcontainers.containers.wait.strategy.Wait;
46
import org.testcontainers.utility.DockerImageName;
57

8+
import java.io.BufferedReader;
9+
import java.io.IOException;
10+
import java.io.InputStreamReader;
11+
import java.io.OutputStream;
12+
import java.net.HttpURLConnection;
13+
import java.net.URL;
14+
import java.nio.charset.StandardCharsets;
15+
616
/**
717
* Testcontainers proxy container for the Docker Model Runner service
818
* provided by Docker Desktop.
@@ -11,12 +21,15 @@
1121
* <p>
1222
* Exposed ports: 80
1323
*/
24+
@Slf4j
1425
public class DockerModelRunnerContainer extends SocatContainer {
1526

1627
private static final String MODEL_RUNNER_ENDPOINT = "model-runner.docker.internal";
1728

1829
private static final int PORT = 80;
1930

31+
private String model;
32+
2033
public DockerModelRunnerContainer(String image) {
2134
this(DockerImageName.parse(image));
2235
}
@@ -27,6 +40,45 @@ public DockerModelRunnerContainer(DockerImageName image) {
2740
waitingFor(Wait.forHttp("/").forResponsePredicate(res -> res.contains("The service is running")));
2841
}
2942

43+
@Override
44+
protected void containerIsStarted(InspectContainerResponse containerInfo) {
45+
if (this.model != null) {
46+
logger().info("Pulling model: {}. Please be patient.", this.model);
47+
48+
String url = getBaseEndpoint() + "/models/create";
49+
String payload = String.format("{\"from\": \"%s\"}", this.model);
50+
51+
try {
52+
HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection();
53+
connection.setRequestMethod("POST");
54+
connection.setRequestProperty("Content-Type", "application/json");
55+
connection.setDoOutput(true);
56+
57+
try (OutputStream os = connection.getOutputStream()) {
58+
os.write(payload.getBytes());
59+
os.flush();
60+
}
61+
62+
try (
63+
BufferedReader br = new BufferedReader(
64+
new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8)
65+
)
66+
) {
67+
while (br.readLine() != null) {}
68+
}
69+
connection.disconnect();
70+
} catch (IOException e) {
71+
logger().error("Failed to pull model {}: {}", this.model, e);
72+
}
73+
logger().info("Finished pulling model: {}", this.model);
74+
}
75+
}
76+
77+
public DockerModelRunnerContainer withModel(String model) {
78+
this.model = model;
79+
return this;
80+
}
81+
3082
public String getBaseEndpoint() {
3183
return "http://" + getHost() + ":" + getMappedPort(PORT);
3284
}

core/src/test/java/org/testcontainers/containers/DockerModelRunnerContainerTest.java

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,34 @@
1010
public class DockerModelRunnerContainerTest {
1111

1212
@Test
13-
public void pullsModelAndExposesInference() {
13+
public void checkStatus() {
1414
assumeThat(System.getenv("CI")).isNull();
1515

16-
String modelName = "ai/smollm2:360M-Q4_K_M";
17-
1816
try (
1917
// container {
2018
DockerModelRunnerContainer dmr = new DockerModelRunnerContainer("alpine/socat:1.7.4.3-r0")
2119
// }
2220
) {
2321
dmr.start();
2422

23+
Response modelResponse = RestAssured.get(dmr.getBaseEndpoint() + "/status").thenReturn();
24+
assertThat(modelResponse.body().asString()).contains("The service is running");
25+
}
26+
}
27+
28+
@Test
29+
public void pullsModelAndExposesInference() {
30+
assumeThat(System.getenv("CI")).isNull();
31+
32+
String modelName = "ai/smollm2:360M-Q4_K_M";
33+
34+
try (
2535
// pullModel {
26-
RestAssured
27-
.given()
28-
.body(String.format("{\"from\":\"%s\"}", modelName))
29-
.post(dmr.getBaseEndpoint() + "/models/create")
30-
.then()
31-
.statusCode(200);
36+
DockerModelRunnerContainer dmr = new DockerModelRunnerContainer("alpine/socat:1.7.4.3-r0")
37+
.withModel(modelName)
3238
// }
39+
) {
40+
dmr.start();
3341

3442
Response modelResponse = RestAssured.get(dmr.getBaseEndpoint() + "/models").thenReturn();
3543
assertThat(modelResponse.body().jsonPath().getList("tags.flatten()")).contains(modelName);

0 commit comments

Comments
 (0)