11package org .testcontainers .containers ;
22
3+ import org .testcontainers .containers .wait .strategy .Wait ;
4+
35import java .io .BufferedReader ;
46import java .io .IOException ;
57import java .io .InputStreamReader ;
810import java .net .URL ;
911import java .nio .charset .StandardCharsets ;
1012
11- public class DockerModelRunnerContainer extends GenericContainer {
13+ public class DockerModelRunnerContainer extends GenericContainer < DockerModelRunnerContainer > {
1214
1315 public static final String MODEL_RUNNER_ENDPOINT = "model-runner.docker.internal" ;
16+
1417 private SocatContainer socat ;
18+
1519 private String model ;
1620
1721 @ Override
1822 public void start () {
19- socat = new SocatContainer ()
20- .withTarget (80 , MODEL_RUNNER_ENDPOINT , 80 );
21- socat .start ();
23+ this .socat =
24+ new SocatContainer ()
25+ .withTarget (80 , MODEL_RUNNER_ENDPOINT )
26+ .waitingFor (Wait .forHttp ("/" ).forResponsePredicate (res -> res .contains ("The service is running" )));
27+ this .socat .start ();
2228 pullModel ();
2329 }
2430
2531 private void pullModel () {
26- logger ().info ("Pulling model: {}. Please be patient, no progress bar yet!" , model );
32+ logger ().info ("Pulling model: {}. Please be patient, no progress bar yet!" , this . model );
2733 try {
28- // Construct JSON payload
29- String json = String .format ("{\" from\" :\" %s\" }" , model );
30- String endpoint = "http://" + socat .getHost () + ":" + socat .getMappedPort (80 ) + "/models/create" ;
34+ String json = String .format ("{\" from\" :\" %s\" }" , this .model );
35+ String endpoint = "http://" + this .socat .getHost () + ":" + this .socat .getMappedPort (80 ) + "/models/create" ;
3136
3237 URL url = new URL (endpoint );
3338 HttpURLConnection connection = (HttpURLConnection ) url .openConnection ();
@@ -40,7 +45,11 @@ private void pullModel() {
4045 os .write (input , 0 , input .length );
4146 }
4247
43- try (BufferedReader br = new BufferedReader (new InputStreamReader (connection .getInputStream (), StandardCharsets .UTF_8 ))) {
48+ try (
49+ BufferedReader br = new BufferedReader (
50+ new InputStreamReader (connection .getInputStream (), StandardCharsets .UTF_8 )
51+ )
52+ ) {
4453 StringBuilder response = new StringBuilder ();
4554 String responseLine ;
4655 while ((responseLine = br .readLine ()) != null ) {
@@ -56,16 +65,19 @@ private void pullModel() {
5665
5766 @ Override
5867 public void stop () {
59- socat .stop ();
68+ this .socat .stop ();
69+ }
70+
71+ public String getBaseEndpoint () {
72+ return "http://" + this .socat .getHost () + ":" + this .socat .getMappedPort (80 );
6073 }
6174
6275 public String getOpenAIEndpoint () {
63- return "http://" + socat . getHost () + ":" + socat . getMappedPort ( 80 ) + "/engines" ;
76+ return getBaseEndpoint ( ) + "/engines" ;
6477 }
6578
6679 public DockerModelRunnerContainer withModel (String modelName ) {
6780 this .model = modelName ;
6881 return this ;
6982 }
70-
7183}
0 commit comments