|
1 | 1 | package org.testcontainers.containers; |
2 | 2 |
|
| 3 | +import com.github.dockerjava.api.command.InspectContainerResponse; |
3 | 4 | import org.testcontainers.containers.wait.strategy.Wait; |
| 5 | +import org.testcontainers.utility.DockerImageName; |
4 | 6 |
|
5 | 7 | import java.io.BufferedReader; |
6 | 8 | import java.io.IOException; |
|
10 | 12 | import java.net.URL; |
11 | 13 | import java.nio.charset.StandardCharsets; |
12 | 14 |
|
13 | | -public class DockerModelRunnerContainer extends GenericContainer<DockerModelRunnerContainer> { |
| 15 | +/** |
| 16 | + * Testcontainers proxy container for the Docker Model Runner service |
| 17 | + * provided by Docker Desktop. |
| 18 | + * <p> |
| 19 | + * Supported images: {@code alpine/socat} |
| 20 | + * <p> |
| 21 | + * Exposed ports: 80 |
| 22 | + */ |
| 23 | +public class DockerModelRunnerContainer extends SocatContainer { |
14 | 24 |
|
15 | 25 | public static final String MODEL_RUNNER_ENDPOINT = "model-runner.docker.internal"; |
16 | 26 |
|
17 | | - private SocatContainer socat; |
18 | | - |
19 | 27 | private String model; |
20 | 28 |
|
| 29 | + public DockerModelRunnerContainer(String image) { |
| 30 | + this(DockerImageName.parse(image)); |
| 31 | + } |
| 32 | + |
| 33 | + public DockerModelRunnerContainer(DockerImageName image) { |
| 34 | + super(image); |
| 35 | + withTarget(80, MODEL_RUNNER_ENDPOINT); |
| 36 | + waitingFor(Wait.forHttp("/").forResponsePredicate(res -> res.contains("The service is running"))); |
| 37 | + } |
| 38 | + |
21 | 39 | @Override |
22 | | - public void start() { |
23 | | - this.socat = |
24 | | - new SocatContainer() |
25 | | - .withTarget(80, MODEL_RUNNER_ENDPOINT) |
26 | | - .waitingFor(Wait.forHttp("/").forResponsePredicate(res -> res.contains("The service is running"))); |
27 | | - this.socat.start(); |
| 40 | + protected void containerIsStarted(InspectContainerResponse containerInfo) { |
28 | 41 | pullModel(); |
29 | 42 | } |
30 | 43 |
|
31 | 44 | private void pullModel() { |
32 | 45 | logger().info("Pulling model: {}. Please be patient, no progress bar yet!", this.model); |
33 | 46 | try { |
34 | 47 | String json = String.format("{\"from\":\"%s\"}", this.model); |
35 | | - String endpoint = "http://" + this.socat.getHost() + ":" + this.socat.getMappedPort(80) + "/models/create"; |
| 48 | + String endpoint = "http://" + getHost() + ":" + getMappedPort(80) + "/models/create"; |
36 | 49 |
|
37 | 50 | URL url = new URL(endpoint); |
38 | 51 | HttpURLConnection connection = (HttpURLConnection) url.openConnection(); |
@@ -63,13 +76,8 @@ private void pullModel() { |
63 | 76 | logger().info("Finished pulling model: {}", model); |
64 | 77 | } |
65 | 78 |
|
66 | | - @Override |
67 | | - public void stop() { |
68 | | - this.socat.stop(); |
69 | | - } |
70 | | - |
71 | 79 | public String getBaseEndpoint() { |
72 | | - return "http://" + this.socat.getHost() + ":" + this.socat.getMappedPort(80); |
| 80 | + return "http://" + getHost() + ":" + getMappedPort(80); |
73 | 81 | } |
74 | 82 |
|
75 | 83 | public String getOpenAIEndpoint() { |
|
0 commit comments