Skip to content

Commit 2541e57

Browse files
committed
add jax testcontainer and whisperjax folder
1 parent c7d51ae commit 2541e57

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import logging
2+
import urllib.request
3+
from urllib.error import URLError
4+
5+
from core.testcontainers.core.container import DockerContainer
6+
from core.testcontainers.core.waiting_utils import wait_container_is_ready
7+
8+
class JAXContainer(DockerContainer):
9+
"""
10+
JAX container for GPU-accelerated numerical computing and machine learning.
11+
12+
Example:
13+
14+
.. doctest::
15+
16+
>>> import jax
17+
>>> from testcontainers.jax import JAXContainer
18+
19+
>>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container:
20+
... # Connect to the container
21+
... jax_container.connect()
22+
...
23+
... # Run a simple JAX computation
24+
... result = jax.numpy.add(1, 1)
25+
... assert result == 2
26+
"""
27+
28+
def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs):
29+
super().__init__(image, **kwargs)
30+
self.with_exposed_ports(8888) # Expose Jupyter notebook port
31+
self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
32+
self.with_env("CUDA_VISIBLE_DEVICES", "all")
33+
self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support
34+
35+
@wait_container_is_ready(URLError)
36+
def _connect(self):
37+
url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
38+
res = urllib.request.urlopen(url)
39+
if res.status != 200:
40+
raise Exception(f"Failed to connect to JAX container. Status: {res.status}")
41+
42+
def connect(self):
43+
"""
44+
Connect to the JAX container and ensure it's ready.
45+
"""
46+
self._connect()
47+
logging.info("Successfully connected to JAX container")
48+
49+
def get_jupyter_url(self):
50+
"""
51+
Get the URL for accessing the Jupyter notebook server.
52+
"""
53+
return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
54+
55+
def run_jax_command(self, command):
56+
"""
57+
Run a JAX command inside the container.
58+
"""
59+
exec_result = self.exec(f"python -c '{command}'")
60+
return exec_result
61+
62+
def start(self):
63+
"""
64+
Start the JAX container.
65+
"""
66+
super().start()
67+
logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}")
68+
return self

0 commit comments

Comments
 (0)