Skip to content

Commit 5442d05

Browse files
feat(core): Add support for ollama module (#618)
- Added a new OllamaContainer class with a few methods to manage the Ollama container. - The `_check_and_add_gpu_capabilities` method checks whether the host has GPUs and adds the necessary capabilities to the container. - The `commit_to_image` method saves the state of a container into an image so that it can be reused, which is especially useful for containers that already have models pulled. - Added tests to verify the functionality of the new class. > Note: This implementation takes inspiration from the Java implementation of the Ollama module. Fixes #617 --------- Co-authored-by: David Ankin <[email protected]>
1 parent ead0f79 commit 5442d05

File tree

5 files changed

+188
-1
lines changed

5 files changed

+188
-1
lines changed

modules/ollama/README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
.. autoclass:: testcontainers.ollama.OllamaContainer
2+
.. title:: testcontainers.ollama.OllamaContainer
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#
2+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
3+
# not use this file except in compliance with the License. You may obtain
4+
# a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11+
# License for the specific language governing permissions and limitations
12+
# under the License.
13+
14+
from os import PathLike
15+
from typing import Any, Optional, TypedDict, Union
16+
17+
from docker.types.containers import DeviceRequest
18+
from requests import get
19+
20+
from testcontainers.core.container import DockerContainer
21+
from testcontainers.core.waiting_utils import wait_for_logs
22+
23+
24+
class OllamaModel(TypedDict):
    """One model entry as returned by ``OllamaContainer.list_models``
    (the ``models`` array of the Ollama ``/api/tags`` response)."""

    name: str  # model name, presumably including the tag — confirm against the API
    model: str  # model identifier as reported by the server
    modified_at: str  # last-modified timestamp, kept as the raw string from the API
    size: int  # model size; presumably bytes — confirm against Ollama API docs
    digest: str  # content digest of the model
    details: dict[str, Any]  # additional metadata; schema not fixed here
31+
32+
33+
class OllamaContainer(DockerContainer):
    """
    Ollama Container

    Example:

        .. doctest::

            >>> from testcontainers.ollama import OllamaContainer
            >>> with OllamaContainer() as ollama:
            ...     ollama.list_models()
            []
    """

    OLLAMA_PORT = 11434

    def __init__(
        self,
        image: str = "ollama/ollama:0.1.44",
        ollama_dir: Optional[Union[str, PathLike]] = None,
        **kwargs,
    ):
        """
        Args:
            image: Docker image to run.
            ollama_dir: Optional host directory mounted over ``/root/.ollama``
                so pulled models persist across container runs.
            **kwargs: Forwarded to :class:`DockerContainer`.
        """
        super().__init__(image=image, **kwargs)
        self.ollama_dir = ollama_dir
        self.with_exposed_ports(OllamaContainer.OLLAMA_PORT)
        self._check_and_add_gpu_capabilities()

    def _check_and_add_gpu_capabilities(self) -> None:
        """Request all GPUs for the container when the Docker host exposes
        the NVIDIA runtime."""
        info = self.get_docker_client().client.info()
        if "nvidia" in info["Runtimes"]:
            # Fix: docker-py expects ``device_requests`` to be a *list* of
            # DeviceRequest objects; a bare DeviceRequest fails at container
            # creation time.
            self._kwargs = {
                **self._kwargs,
                "device_requests": [DeviceRequest(count=-1, capabilities=[["gpu"]])],
            }

    def start(self) -> "OllamaContainer":
        """
        Start the Ollama server and block until it reports it is listening.
        """
        if self.ollama_dir:
            # str() so PathLike values (e.g. pathlib.Path) are accepted as
            # volume keys by docker-py.
            self.with_volume_mapping(str(self.ollama_dir), "/root/.ollama", "rw")
        super().start()
        wait_for_logs(self, "Listening on ", timeout=30)

        return self

    def get_endpoint(self) -> str:
        """
        Return the base URL of the Ollama server, e.g. ``http://host:port``.
        """
        host = self.get_container_host_ip()
        exposed_port = self.get_exposed_port(OllamaContainer.OLLAMA_PORT)
        url = f"http://{host}:{exposed_port}"
        return url

    @property
    def id(self) -> str:
        """
        Return the id of the underlying container.
        """
        return self._container.id

    def pull_model(self, model_name: str) -> None:
        """
        Pull a model into the running Ollama server.

        Args:
            model_name (str): Name of the model
        """
        self.exec(f"ollama pull {model_name}")

    def list_models(self) -> list[OllamaModel]:
        """
        Return the models currently known to the server (``GET /api/tags``).

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        endpoint = self.get_endpoint()
        response = get(url=f"{endpoint}/api/tags")
        response.raise_for_status()
        return response.json().get("models", [])

    def commit_to_image(self, image_name: str) -> None:
        """
        Commit the current container to a new image

        No-op if an image with that name already exists or the container has
        not been started yet.

        Args:
            image_name (str): Name of the new image
        """
        docker_client = self.get_docker_client()
        existing_images = docker_client.client.images.list(name=image_name)
        if not existing_images and self.id:
            # Clear the session id label so Ryuk does not reap the
            # committed image together with this test session.
            docker_client.client.containers.get(self.id).commit(
                repository=image_name, conf={"Labels": {"org.testcontainers.session-id": ""}}
            )

modules/ollama/tests/test_ollama.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import random
2+
import string
3+
from pathlib import Path
4+
5+
import requests
6+
from testcontainers.ollama import OllamaContainer
7+
8+
9+
def random_string(length=6):
    """Return a string of *length* random lowercase ASCII letters."""
    letters = random.choices(string.ascii_lowercase, k=length)
    return "".join(letters)
11+
12+
13+
def test_ollama_container():
    """The root endpoint of a started container answers with the banner."""
    with OllamaContainer() as ollama:
        endpoint = ollama.get_endpoint()
        response = requests.get(endpoint)
        assert response.status_code == 200
        assert response.text == "Ollama is running"
19+
20+
21+
def test_with_default_config():
    """The server reports the version matching the pinned image tag."""
    # NOTE: the context manager already starts the container; an extra
    # explicit start() would launch a second container that is never
    # cleaned up, so it was removed.
    with OllamaContainer("ollama/ollama:0.1.26") as ollama:
        response = requests.get(f"{ollama.get_endpoint()}/api/version")
        version = response.json().get("version")
        assert version == "0.1.26"
27+
28+
29+
def test_download_model_and_commit_to_image():
    """Pull a model, commit the container to an image, verify the image."""
    new_image_name = f"tc-ollama-allminilm-{random_string(length=4).lower()}"
    # NOTE: the context manager already starts the container; the explicit
    # start() calls were removed because a second start() launches a
    # duplicate container that is never cleaned up. An unused /api/tags
    # request was also dropped.
    with OllamaContainer("ollama/ollama:0.1.26") as ollama:
        # Pull the model
        ollama.pull_model("all-minilm")

        model_name = ollama.list_models()[0].get("name")
        assert "all-minilm" in model_name

        # Commit the container state to a new image
        ollama.commit_to_image(new_image_name)

    # Verify the new image
    with OllamaContainer(new_image_name) as ollama:
        response = requests.get(f"{ollama.get_endpoint()}/api/tags")
        model_name = response.json().get("models", [])[0].get("name")
        assert "all-minilm" in model_name
49+
50+
51+
def test_models_saved_in_folder(tmp_path: Path):
    """Models pulled with a mounted ollama_dir persist across containers."""
    image = "ollama/ollama:0.1.26"

    with OllamaContainer(image, ollama_dir=tmp_path) as ollama:
        assert not ollama.list_models()
        ollama.pull_model("all-minilm")
        assert len(ollama.list_models()) == 1
        assert "all-minilm" in ollama.list_models()[0].get("name")

    # A fresh container over the same directory sees the pulled model.
    with OllamaContainer(image, ollama_dir=tmp_path) as ollama:
        assert len(ollama.list_models()) == 1
        assert "all-minilm" in ollama.list_models()[0].get("name")

poetry.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ packages = [
5151
{ include = "testcontainers", from = "modules/nats" },
5252
{ include = "testcontainers", from = "modules/neo4j" },
5353
{ include = "testcontainers", from = "modules/nginx" },
54+
{ include = "testcontainers", from = "modules/ollama" },
5455
{ include = "testcontainers", from = "modules/opensearch" },
5556
{ include = "testcontainers", from = "modules/oracle-free" },
5657
{ include = "testcontainers", from = "modules/postgres" },
@@ -127,6 +128,7 @@ nats = ["nats-py"]
127128
neo4j = ["neo4j"]
128129
nginx = []
129130
opensearch = ["opensearch-py"]
131+
ollama = []
130132
oracle = ["sqlalchemy", "oracledb"]
131133
oracle-free = ["sqlalchemy", "oracledb"]
132134
postgres = []
@@ -272,6 +274,7 @@ mypy_path = [
272274
# "modules/mysql",
273275
# "modules/neo4j",
274276
# "modules/nginx",
277+
# "modules/ollama",
275278
# "modules/opensearch",
276279
# "modules/oracle",
277280
# "modules/postgres",

0 commit comments

Comments
 (0)