Skip to content

Commit 496465b

Browse files
committed
Polish
1 parent 6df84ad commit 496465b

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

core/src/main/java/org/testcontainers/containers/DockerModelRunnerContainer.java

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package org.testcontainers.containers;
22

3+
import com.github.dockerjava.api.command.InspectContainerResponse;
34
import org.testcontainers.containers.wait.strategy.Wait;
5+
import org.testcontainers.utility.DockerImageName;
46

57
import java.io.BufferedReader;
68
import java.io.IOException;
@@ -10,29 +12,40 @@
1012
import java.net.URL;
1113
import java.nio.charset.StandardCharsets;
1214

13-
public class DockerModelRunnerContainer extends GenericContainer<DockerModelRunnerContainer> {
15+
/**
16+
* Testcontainers proxy container for the Docker Model Runner service
17+
* provided by Docker Desktop.
18+
* <p>
19+
* Supported images: {@code alpine/socat}
20+
* <p>
21+
* Exposed ports: 80
22+
*/
23+
public class DockerModelRunnerContainer extends SocatContainer {
1424

1525
public static final String MODEL_RUNNER_ENDPOINT = "model-runner.docker.internal";
1626

17-
private SocatContainer socat;
18-
1927
private String model;
2028

29+
public DockerModelRunnerContainer(String image) {
30+
this(DockerImageName.parse(image));
31+
}
32+
33+
public DockerModelRunnerContainer(DockerImageName image) {
34+
super(image);
35+
withTarget(80, MODEL_RUNNER_ENDPOINT);
36+
waitingFor(Wait.forHttp("/").forResponsePredicate(res -> res.contains("The service is running")));
37+
}
38+
2139
@Override
22-
public void start() {
23-
this.socat =
24-
new SocatContainer()
25-
.withTarget(80, MODEL_RUNNER_ENDPOINT)
26-
.waitingFor(Wait.forHttp("/").forResponsePredicate(res -> res.contains("The service is running")));
27-
this.socat.start();
40+
protected void containerIsStarted(InspectContainerResponse containerInfo) {
2841
pullModel();
2942
}
3043

3144
private void pullModel() {
3245
logger().info("Pulling model: {}. Please be patient, no progress bar yet!", this.model);
3346
try {
3447
String json = String.format("{\"from\":\"%s\"}", this.model);
35-
String endpoint = "http://" + this.socat.getHost() + ":" + this.socat.getMappedPort(80) + "/models/create";
48+
String endpoint = "http://" + getHost() + ":" + getMappedPort(80) + "/models/create";
3649

3750
URL url = new URL(endpoint);
3851
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
@@ -63,13 +76,8 @@ private void pullModel() {
6376
logger().info("Finished pulling model: {}", model);
6477
}
6578

66-
@Override
67-
public void stop() {
68-
this.socat.stop();
69-
}
70-
7179
public String getBaseEndpoint() {
72-
return "http://" + this.socat.getHost() + ":" + this.socat.getMappedPort(80);
80+
return "http://" + getHost() + ":" + getMappedPort(80);
7381
}
7482

7583
public String getOpenAIEndpoint() {

0 commit comments

Comments
 (0)