Skip to content

Commit 62a383b

Browse files
authored
Add example to run Hugging Face models using Ollama (#8771)
1 parent ca4981e commit 62a383b

File tree

5 files changed

+180
-0
lines changed

5 files changed

+180
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
plugins {
2+
id 'java'
3+
}
4+
5+
repositories {
6+
mavenCentral()
7+
}
8+
9+
dependencies {
10+
testImplementation 'org.testcontainers:ollama'
11+
testImplementation 'org.assertj:assertj-core:3.25.3'
12+
testImplementation 'ch.qos.logback:logback-classic:1.3.14'
13+
testImplementation 'org.junit.jupiter:junit-jupiter:5.10.2'
14+
testImplementation 'io.rest-assured:rest-assured:5.4.0'
15+
}
16+
17+
test {
18+
useJUnitPlatform()
19+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package com.example.ollamahf;
2+
3+
import com.github.dockerjava.api.command.InspectContainerResponse;
4+
import org.testcontainers.containers.ContainerLaunchException;
5+
import org.testcontainers.ollama.OllamaContainer;
6+
import org.testcontainers.utility.DockerImageName;
7+
8+
import java.io.IOException;
9+
10+
public class OllamaHuggingFaceContainer extends OllamaContainer {
11+
12+
private final HuggingFaceModel huggingFaceModel;
13+
14+
public OllamaHuggingFaceContainer(HuggingFaceModel model) {
15+
super(DockerImageName.parse("ollama/ollama:0.1.47"));
16+
this.huggingFaceModel = model;
17+
}
18+
19+
@Override
20+
protected void containerIsStarted(InspectContainerResponse containerInfo, boolean reused) {
21+
super.containerIsStarted(containerInfo, reused);
22+
if (reused || huggingFaceModel == null) {
23+
return;
24+
}
25+
26+
try {
27+
executeCommand("apt-get", "update");
28+
executeCommand("apt-get", "upgrade", "-y");
29+
executeCommand("apt-get", "install", "-y", "python3-pip");
30+
executeCommand("pip", "install", "huggingface-hub");
31+
executeCommand(
32+
"huggingface-cli",
33+
"download",
34+
huggingFaceModel.repository,
35+
huggingFaceModel.model,
36+
"--local-dir",
37+
"."
38+
);
39+
executeCommand("sh", "-c", String.format("echo '%s' > Modelfile", huggingFaceModel.modelfileContent));
40+
executeCommand("ollama", "create", huggingFaceModel.model, "-f", "Modelfile");
41+
executeCommand("rm", huggingFaceModel.model);
42+
} catch (IOException | InterruptedException e) {
43+
throw new ContainerLaunchException(e.getMessage());
44+
}
45+
}
46+
47+
private void executeCommand(String... command) throws ContainerLaunchException, IOException, InterruptedException {
48+
ExecResult execResult = execInContainer(command);
49+
if (execResult.getExitCode() > 0) {
50+
throw new ContainerLaunchException(
51+
"Failed to execute " + String.join(" ", command) + ": " + execResult.getStderr()
52+
);
53+
}
54+
}
55+
56+
public static class HuggingFaceModel {
57+
58+
public final String repository;
59+
60+
public final String model;
61+
62+
public String modelfileContent;
63+
64+
public HuggingFaceModel(String repository, String model) {
65+
this.repository = repository;
66+
this.model = model;
67+
this.modelfileContent = "FROM " + model;
68+
}
69+
}
70+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package com.example.ollamahf;
2+
3+
import io.restassured.http.Header;
4+
import org.junit.jupiter.api.Test;
5+
import org.testcontainers.ollama.OllamaContainer;
6+
import org.testcontainers.utility.DockerImageName;
7+
8+
import java.util.List;
9+
10+
import static io.restassured.RestAssured.given;
11+
import static org.assertj.core.api.Assertions.assertThat;
12+
13+
public class OllamaHuggingFaceTest {
14+
15+
@Test
16+
public void embeddingModelWithHuggingFace() {
17+
String repository = "CompendiumLabs/bge-small-en-v1.5-gguf";
18+
String model = "bge-small-en-v1.5-q4_k_m.gguf";
19+
String imageName = "embedding-model-from-hugging-face";
20+
OllamaContainer ollama = new OllamaContainer(
21+
DockerImageName.parse(imageName).asCompatibleSubstituteFor("ollama/ollama:0.1.47")
22+
);
23+
boolean imageExists = ollama
24+
.getDockerClient()
25+
.listImagesCmd()
26+
.exec()
27+
.stream()
28+
.anyMatch(image -> image.getRepoTags()[0].equals(imageName + ":latest"));
29+
if (!imageExists) {
30+
createImage(imageName, repository, model);
31+
}
32+
ollama.start();
33+
34+
String modelName = given()
35+
.baseUri(ollama.getEndpoint())
36+
.get("/api/tags")
37+
.jsonPath()
38+
.getString("models[0].name");
39+
assertThat(modelName).contains(model + ":latest");
40+
41+
List<Float> embedding = given()
42+
.baseUri(ollama.getEndpoint())
43+
.header(new Header("Content-Type", "application/json"))
44+
.body(new EmbeddingRequest(model + ":latest", "Hello from Testcontainers!"))
45+
.post("/api/embeddings")
46+
.jsonPath()
47+
.getList("embedding");
48+
49+
assertThat(embedding).isNotNull();
50+
assertThat(embedding.isEmpty()).isFalse();
51+
}
52+
53+
private static void createImage(String imageName, String repository, String model) {
54+
OllamaHuggingFaceContainer.HuggingFaceModel hfModel = new OllamaHuggingFaceContainer.HuggingFaceModel(
55+
repository,
56+
model
57+
);
58+
OllamaHuggingFaceContainer huggingFaceContainer = new OllamaHuggingFaceContainer(hfModel);
59+
huggingFaceContainer.start();
60+
huggingFaceContainer.commitToImage(imageName);
61+
}
62+
63+
public static class EmbeddingRequest {
64+
65+
public final String model;
66+
67+
public final String prompt;
68+
69+
public EmbeddingRequest(String model, String prompt) {
70+
this.model = model;
71+
this.prompt = prompt;
72+
}
73+
}
74+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<configuration>
2+
3+
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
4+
<!-- encoders are assigned the type
5+
ch.qos.logback.classic.encoder.PatternLayoutEncoder by default -->
6+
<encoder>
7+
<pattern>%d{HH:mm:ss.SSS} %-5level %logger - %msg%n</pattern>
8+
</encoder>
9+
</appender>
10+
11+
<root level="INFO">
12+
<appender-ref ref="STDOUT"/>
13+
</root>
14+
15+
<logger name="org.testcontainers" level="INFO"/>
16+
</configuration>

examples/settings.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ include 'zookeeper'
3535
include 'hazelcast'
3636
include 'nats'
3737
include 'sftp'
38+
include 'ollama-hugging-face'
3839

3940
ext.isCI = System.getenv("CI") != null
4041

0 commit comments

Comments
 (0)