Skip to content

Commit 2b08e6d

Browse files
committed
Add DockerModelRunnerContainer to core
1 parent e730674 commit 2b08e6d

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed

core/build.gradle

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ dependencies {
123123
testImplementation 'org.assertj:assertj-core:3.26.3'
124124
testImplementation 'io.rest-assured:rest-assured:5.5.0'
125125

126+
// DockerModelRunnerContainer tests
127+
testImplementation("com.openai:openai-java:1.3.0")
128+
126129
jarFileTestCompileOnly "org.projectlombok:lombok:${lombok.version}"
127130
jarFileTestAnnotationProcessor "org.projectlombok:lombok:${lombok.version}"
128131
jarFileTestImplementation 'junit:junit:4.13.2'
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package org.testcontainers.containers;
2+
3+
import java.io.BufferedReader;
4+
import java.io.IOException;
5+
import java.io.InputStreamReader;
6+
import java.io.OutputStream;
7+
import java.net.HttpURLConnection;
8+
import java.net.URL;
9+
import java.nio.charset.StandardCharsets;
10+
11+
public class DockerModelRunnerContainer extends GenericContainer {
12+
13+
public static final String MODEL_RUNNER_ENDPOINT = "model-runner.docker.internal";
14+
private SocatContainer socat;
15+
private String model;
16+
17+
@Override
18+
public void start() {
19+
socat = new SocatContainer()
20+
.withTarget(80, MODEL_RUNNER_ENDPOINT, 80);
21+
socat.start();
22+
pullModel();
23+
}
24+
25+
private void pullModel() {
26+
logger().info("Pulling model: {}. Please be patient, no progress bar yet!", model);
27+
try {
28+
// Construct JSON payload
29+
String json = String.format("{\"from\":\"%s\"}", model);
30+
String endpoint = "http://" + socat.getHost() + ":" + socat.getMappedPort(80) + "/models/create";
31+
32+
URL url = new URL(endpoint);
33+
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
34+
connection.setRequestMethod("POST");
35+
connection.setRequestProperty("Content-Type", "application/json");
36+
connection.setDoOutput(true);
37+
38+
try (OutputStream os = connection.getOutputStream()) {
39+
byte[] input = json.getBytes(StandardCharsets.UTF_8);
40+
os.write(input, 0, input.length);
41+
}
42+
43+
try (BufferedReader br = new BufferedReader(new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
44+
StringBuilder response = new StringBuilder();
45+
String responseLine;
46+
while ((responseLine = br.readLine()) != null) {
47+
response.append(responseLine.trim());
48+
}
49+
logger().info(response.toString());
50+
}
51+
} catch (IOException e) {
52+
throw new RuntimeException(e);
53+
}
54+
logger().info("Finished pulling model: {}", model);
55+
}
56+
57+
@Override
58+
public void stop() {
59+
socat.stop();
60+
}
61+
62+
public String getOpenAIEndpoint() {
63+
return "http://" + socat.getHost() + ":" + socat.getMappedPort(80) + "/engines";
64+
}
65+
66+
public DockerModelRunnerContainer withModel(String modelName) {
67+
this.model = modelName;
68+
return this;
69+
}
70+
71+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package org.testcontainers.containers;
2+
3+
import com.openai.client.OpenAIClient;
4+
import com.openai.client.okhttp.OpenAIOkHttpClient;
5+
import com.openai.models.chat.completions.ChatCompletion;
6+
import com.openai.models.chat.completions.ChatCompletionCreateParams;
7+
import org.junit.Test;
8+
9+
public class DockerModelRunnerContainerTest {
10+
11+
@Test
12+
public void pullsModelAndExposesInference() {
13+
String modelName = "ai/smollm2:360M-Q4_K_M";
14+
15+
try (DockerModelRunnerContainer dmr = new DockerModelRunnerContainer()
16+
.withModel(modelName);) {
17+
dmr.start();
18+
19+
OpenAIClient client = OpenAIOkHttpClient.builder()
20+
.baseUrl(dmr.getOpenAIEndpoint())
21+
.build();
22+
23+
ChatCompletionCreateParams params = ChatCompletionCreateParams.builder()
24+
.addUserMessage("Say this is a test")
25+
.model(modelName)
26+
.build();
27+
ChatCompletion chatCompletion = client.chat().completions().create(params);
28+
29+
String answer = chatCompletion.toString();
30+
System.out.println(answer);
31+
32+
}
33+
}
34+
35+
}

0 commit comments

Comments
 (0)