|
| 1 | +import logging |
| 2 | +import urllib.request |
| 3 | +from urllib.error import URLError |
| 4 | + |
| 5 | +from core.testcontainers.core.container import DockerContainer |
| 6 | +from core.testcontainers.core.waiting_utils import wait_container_is_ready |
| 7 | + |
| 8 | +class JAXContainer(DockerContainer): |
| 9 | + """ |
| 10 | + JAX container for GPU-accelerated numerical computing and machine learning. |
| 11 | +
|
| 12 | + Example: |
| 13 | +
|
| 14 | + .. doctest:: |
| 15 | +
|
| 16 | + >>> import jax |
| 17 | + >>> from testcontainers.jax import JAXContainer |
| 18 | +
|
| 19 | + >>> with JAXContainer("nvcr.io/nvidia/jax:23.08-py3") as jax_container: |
| 20 | + ... # Connect to the container |
| 21 | + ... jax_container.connect() |
| 22 | + ... |
| 23 | + ... # Run a simple JAX computation |
| 24 | + ... result = jax.numpy.add(1, 1) |
| 25 | + ... assert result == 2 |
| 26 | + """ |
| 27 | + |
| 28 | + def __init__(self, image="nvcr.io/nvidia/jax:23.08-py3", **kwargs): |
| 29 | + super().__init__(image, **kwargs) |
| 30 | + self.with_exposed_ports(8888) # Expose Jupyter notebook port |
| 31 | + self.with_env("NVIDIA_VISIBLE_DEVICES", "all") |
| 32 | + self.with_env("CUDA_VISIBLE_DEVICES", "all") |
| 33 | + self.with_kwargs(runtime="nvidia") # Use NVIDIA runtime for GPU support |
| 34 | + |
| 35 | + @wait_container_is_ready(URLError) |
| 36 | + def _connect(self): |
| 37 | + url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" |
| 38 | + res = urllib.request.urlopen(url) |
| 39 | + if res.status != 200: |
| 40 | + raise Exception(f"Failed to connect to JAX container. Status: {res.status}") |
| 41 | + |
| 42 | + def connect(self): |
| 43 | + """ |
| 44 | + Connect to the JAX container and ensure it's ready. |
| 45 | + """ |
| 46 | + self._connect() |
| 47 | + logging.info("Successfully connected to JAX container") |
| 48 | + |
| 49 | + def get_jupyter_url(self): |
| 50 | + """ |
| 51 | + Get the URL for accessing the Jupyter notebook server. |
| 52 | + """ |
| 53 | + return f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}" |
| 54 | + |
| 55 | + def run_jax_command(self, command): |
| 56 | + """ |
| 57 | + Run a JAX command inside the container. |
| 58 | + """ |
| 59 | + exec_result = self.exec(f"python -c '{command}'") |
| 60 | + return exec_result |
| 61 | + |
| 62 | + def start(self): |
| 63 | + """ |
| 64 | + Start the JAX container. |
| 65 | + """ |
| 66 | + super().start() |
| 67 | + logging.info(f"JAX container started. Jupyter URL: {self.get_jupyter_url()}") |
| 68 | + return self |
0 commit comments