1
1
package org .testcontainers .containers ;
2
2
3
+ import com .github .dockerjava .api .command .InspectContainerResponse ;
4
+ import lombok .extern .slf4j .Slf4j ;
3
5
import org .testcontainers .containers .wait .strategy .Wait ;
4
6
import org .testcontainers .utility .DockerImageName ;
5
7
8
+ import java .io .BufferedReader ;
9
+ import java .io .IOException ;
10
+ import java .io .InputStreamReader ;
11
+ import java .io .OutputStream ;
12
+ import java .net .HttpURLConnection ;
13
+ import java .net .URL ;
14
+ import java .nio .charset .StandardCharsets ;
15
+
6
16
/**
7
17
* Testcontainers proxy container for the Docker Model Runner service
8
18
* provided by Docker Desktop.
11
21
* <p>
12
22
* Exposed ports: 80
13
23
*/
24
+ @ Slf4j
14
25
public class DockerModelRunnerContainer extends SocatContainer {
15
26
16
27
private static final String MODEL_RUNNER_ENDPOINT = "model-runner.docker.internal" ;
17
28
18
29
private static final int PORT = 80 ;
19
30
31
+ private String model ;
32
+
20
33
public DockerModelRunnerContainer (String image ) {
21
34
this (DockerImageName .parse (image ));
22
35
}
@@ -27,6 +40,45 @@ public DockerModelRunnerContainer(DockerImageName image) {
27
40
waitingFor (Wait .forHttp ("/" ).forResponsePredicate (res -> res .contains ("The service is running" )));
28
41
}
29
42
43
+ @ Override
44
+ protected void containerIsStarted (InspectContainerResponse containerInfo ) {
45
+ if (this .model != null ) {
46
+ logger ().info ("Pulling model: {}. Please be patient." , this .model );
47
+
48
+ String url = getBaseEndpoint () + "/models/create" ;
49
+ String payload = String .format ("{\" from\" : \" %s\" }" , this .model );
50
+
51
+ try {
52
+ HttpURLConnection connection = (HttpURLConnection ) new URL (url ).openConnection ();
53
+ connection .setRequestMethod ("POST" );
54
+ connection .setRequestProperty ("Content-Type" , "application/json" );
55
+ connection .setDoOutput (true );
56
+
57
+ try (OutputStream os = connection .getOutputStream ()) {
58
+ os .write (payload .getBytes ());
59
+ os .flush ();
60
+ }
61
+
62
+ try (
63
+ BufferedReader br = new BufferedReader (
64
+ new InputStreamReader (connection .getInputStream (), StandardCharsets .UTF_8 )
65
+ )
66
+ ) {
67
+ while (br .readLine () != null ) {}
68
+ }
69
+ connection .disconnect ();
70
+ } catch (IOException e ) {
71
+ logger ().error ("Failed to pull model {}: {}" , this .model , e );
72
+ }
73
+ logger ().info ("Finished pulling model: {}" , this .model );
74
+ }
75
+ }
76
+
77
+ public DockerModelRunnerContainer withModel (String model ) {
78
+ this .model = model ;
79
+ return this ;
80
+ }
81
+
30
82
public String getBaseEndpoint () {
31
83
return "http://" + getHost () + ":" + getMappedPort (PORT );
32
84
}
0 commit comments