Skip to content

Commit 6df84ad

Browse files
committed
Polish
1 parent 861f324 commit 6df84ad

File tree

3 files changed

+33
-34
lines changed

3 files changed

+33
-34
lines changed

core/build.gradle

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,6 @@ dependencies {
123123
testImplementation 'org.assertj:assertj-core:3.26.3'
124124
testImplementation 'io.rest-assured:rest-assured:5.5.0'
125125

126-
// DockerModelRunnerContainer tests
127-
testImplementation("com.openai:openai-java:1.3.0")
128-
129126
jarFileTestCompileOnly "org.projectlombok:lombok:${lombok.version}"
130127
jarFileTestAnnotationProcessor "org.projectlombok:lombok:${lombok.version}"
131128
jarFileTestImplementation 'junit:junit:4.13.2'

core/src/main/java/org/testcontainers/containers/DockerModelRunnerContainer.java

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package org.testcontainers.containers;
22

3+
import org.testcontainers.containers.wait.strategy.Wait;
4+
35
import java.io.BufferedReader;
46
import java.io.IOException;
57
import java.io.InputStreamReader;
@@ -8,26 +10,29 @@
810
import java.net.URL;
911
import java.nio.charset.StandardCharsets;
1012

11-
public class DockerModelRunnerContainer extends GenericContainer {
13+
public class DockerModelRunnerContainer extends GenericContainer<DockerModelRunnerContainer> {
1214

1315
public static final String MODEL_RUNNER_ENDPOINT = "model-runner.docker.internal";
16+
1417
private SocatContainer socat;
18+
1519
private String model;
1620

1721
@Override
1822
public void start() {
19-
socat = new SocatContainer()
20-
.withTarget(80, MODEL_RUNNER_ENDPOINT, 80);
21-
socat.start();
23+
this.socat =
24+
new SocatContainer()
25+
.withTarget(80, MODEL_RUNNER_ENDPOINT)
26+
.waitingFor(Wait.forHttp("/").forResponsePredicate(res -> res.contains("The service is running")));
27+
this.socat.start();
2228
pullModel();
2329
}
2430

2531
private void pullModel() {
26-
logger().info("Pulling model: {}. Please be patient, no progress bar yet!", model);
32+
logger().info("Pulling model: {}. Please be patient, no progress bar yet!", this.model);
2733
try {
28-
// Construct JSON payload
29-
String json = String.format("{\"from\":\"%s\"}", model);
30-
String endpoint = "http://" + socat.getHost() + ":" + socat.getMappedPort(80) + "/models/create";
34+
String json = String.format("{\"from\":\"%s\"}", this.model);
35+
String endpoint = "http://" + this.socat.getHost() + ":" + this.socat.getMappedPort(80) + "/models/create";
3136

3237
URL url = new URL(endpoint);
3338
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
@@ -40,7 +45,11 @@ private void pullModel() {
4045
os.write(input, 0, input.length);
4146
}
4247

43-
try (BufferedReader br = new BufferedReader(new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
48+
try (
49+
BufferedReader br = new BufferedReader(
50+
new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8)
51+
)
52+
) {
4453
StringBuilder response = new StringBuilder();
4554
String responseLine;
4655
while ((responseLine = br.readLine()) != null) {
@@ -56,16 +65,19 @@ private void pullModel() {
5665

5766
@Override
5867
public void stop() {
59-
socat.stop();
68+
this.socat.stop();
69+
}
70+
71+
public String getBaseEndpoint() {
72+
return "http://" + this.socat.getHost() + ":" + this.socat.getMappedPort(80);
6073
}
6174

6275
public String getOpenAIEndpoint() {
63-
return "http://" + socat.getHost() + ":" + socat.getMappedPort(80) + "/engines";
76+
return getBaseEndpoint() + "/engines";
6477
}
6578

6679
public DockerModelRunnerContainer withModel(String modelName) {
6780
this.model = modelName;
6881
return this;
6982
}
70-
7183
}
Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,25 @@
11
package org.testcontainers.containers;
22

3-
import com.openai.client.OpenAIClient;
4-
import com.openai.client.okhttp.OpenAIOkHttpClient;
5-
import com.openai.models.chat.completions.ChatCompletion;
6-
import com.openai.models.chat.completions.ChatCompletionCreateParams;
3+
import io.restassured.RestAssured;
4+
import io.restassured.response.Response;
75
import org.junit.Test;
86

7+
import static org.assertj.core.api.Assertions.assertThat;
8+
99
public class DockerModelRunnerContainerTest {
1010

1111
@Test
1212
public void pullsModelAndExposesInference() {
1313
String modelName = "ai/smollm2:360M-Q4_K_M";
1414

15-
try (DockerModelRunnerContainer dmr = new DockerModelRunnerContainer()
16-
.withModel(modelName);) {
15+
try (DockerModelRunnerContainer dmr = new DockerModelRunnerContainer().withModel(modelName)) {
1716
dmr.start();
1817

19-
OpenAIClient client = OpenAIOkHttpClient.builder()
20-
.baseUrl(dmr.getOpenAIEndpoint())
21-
.build();
22-
23-
ChatCompletionCreateParams params = ChatCompletionCreateParams.builder()
24-
.addUserMessage("Say this is a test")
25-
.model(modelName)
26-
.build();
27-
ChatCompletion chatCompletion = client.chat().completions().create(params);
28-
29-
String answer = chatCompletion.toString();
30-
System.out.println(answer);
18+
Response response = RestAssured.get(dmr.getBaseEndpoint() + "/models").thenReturn();
19+
assertThat(response.body().jsonPath().getList("tags.flatten()")).contains(modelName);
3120

21+
Response openAiResponse = RestAssured.get(dmr.getOpenAIEndpoint() + "/v1/models").prettyPeek().thenReturn();
22+
assertThat(openAiResponse.body().jsonPath().getList("data.id")).contains(modelName);
3223
}
3324
}
34-
3525
}

0 commit comments

Comments
 (0)