Skip to content

Commit 9f5d6b8

Browse files
committed
add readme draft
1 parent 66d5460 commit 9f5d6b8

File tree

3 files changed

+38
-126
lines changed

3 files changed

+38
-126
lines changed

modules/jax/README.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Testcontainers: JAX
2+
3+
## Docker Containers for JAX with GPU Support
4+
5+
1. **Official JAX Docker Container**
6+
- **Container**: `jax/jax:cuda-12.0`
7+
- **Documentation**: [JAX Docker](https://github.com/google/jax/blob/main/docker/README.md)
8+
9+
2. **NVIDIA Docker Container**
10+
- **Container**: `nvidia/cuda:12.0-cudnn8-devel-ubuntu20.04`
11+
- **Documentation**: [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda)
12+
13+
## Benefits of Having This Container
14+
15+
1. **Optimized Performance**: JAX uses XLA to compile and run NumPy programs on GPUs, which can significantly speed up numerical computations and machine learning tasks. A container specifically optimized for JAX with CUDA ensures that the environment is configured to leverage GPU acceleration fully.
16+
17+
2. **Reproducibility**: Containers encapsulate all dependencies, libraries, and configurations needed to run JAX, ensuring that the environment is consistent across different systems. This is crucial for reproducible research and development.
18+
19+
3. **Ease of Use**: Users can easily pull and run the container without worrying about the complex setup required for GPU support and JAX configuration. This reduces the barrier to entry for new users and accelerates development workflows.
20+
21+
4. **Isolation and Security**: Containers provide an isolated environment, which enhances security by limiting the impact of potential vulnerabilities. They also avoid conflicts with other software on the host system.
22+
23+
## Relevant Reading Material
24+
25+
1. **JAX Documentation**
26+
- [JAX Quickstart](https://github.com/google/jax#quickstart)
27+
- [JAX Transformations](https://github.com/google/jax#transformations)
28+
- [JAX Installation Guide](https://github.com/google/jax#installation)
29+
30+
2. **NVIDIA Docker Documentation**
31+
- [NVIDIA Docker Hub](https://hub.docker.com/r/nvidia/cuda)
32+
- [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit)
33+
34+
3. **Docker Best Practices**
35+
- [Docker Documentation](https://docs.docker.com/get-started/)
36+
- [Best practices for writing Dockerfiles](https://docs.docker.com/develop/develop-images/dockerfile_best-practices/)
Lines changed: 1 addition & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,126 +1 @@
1-
import logging
2-
import tempfile
3-
import time
4-
from typing import Optional
5-
6-
from core.testcontainers.core.container import DockerContainer
7-
from core.testcontainers.core.waiting_utils import wait_container_is_ready
8-
from urllib.error import URLError
9-
10-
class WhisperJAXContainer(DockerContainer):
    """
    Whisper-JAX container for fast speech recognition and transcription.

    The container is started from the ``nvcr.io/nvidia/jax:23.08-py3`` image
    with the NVIDIA runtime. Its start command installs ``whisper-jax`` and
    its dependencies and then launches a Jupyter notebook server on port
    8888, which :meth:`connect` polls to detect readiness. Transcription is
    performed by executing generated Python scripts inside the container, so
    any file path passed to :meth:`transcribe_file` is resolved inside the
    container, not on the host.

    Example:

        .. doctest::

            >>> from testcontainers.whisper_jax import WhisperJAXContainer

            >>> with WhisperJAXContainer("openai/whisper-large-v2") as whisper:
            ...     # Connect to the container
            ...     whisper.connect()
            ...
            ...     # Transcribe an audio file
            ...     result = whisper.transcribe_file("path/to/audio/file.wav")
            ...     print(result['text'])
            ...
            ...     # Transcribe a YouTube video
            ...     result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
            ...     print(result['text'])
    """

    def __init__(self, model_name: str = "openai/whisper-large-v2", **kwargs):
        """
        Configure the container (image, port, GPU settings, start command).

        :param model_name: Model identifier handed to ``FlaxWhisperPipline``
            by the transcription scripts.
        :param kwargs: Forwarded unchanged to ``DockerContainer.__init__``.
        """
        super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs)
        self.model_name = model_name
        self.with_exposed_ports(8888)  # Expose Jupyter notebook port
        self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
        # CUDA_VISIBLE_DEVICES is deliberately NOT set: it expects a list of
        # device indices/UUIDs, and a value such as "all" is an invalid entry
        # that hides every GPU from CUDA. NVIDIA_VISIBLE_DEVICES=all already
        # exposes all GPUs through the NVIDIA runtime.
        self.with_kwargs(runtime="nvidia")  # Use NVIDIA runtime for GPU support

        # Install required dependencies, then start Jupyter without
        # authentication so the readiness probe can reach it.
        self.with_command(
            "sh -c '"
            "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
            "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && "
            "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
            "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''"
            "'"
        )

    @wait_container_is_ready(URLError)
    def _connect(self):
        """
        Probe the Jupyter endpoint once; retried by the decorator on ``URLError``.

        :raises Exception: If the server answers with a status other than 200.
        """
        # Imported here because the module only imports ``urllib.error``;
        # without this, ``urllib.request`` is an unresolved name.
        import urllib.request

        url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
        # Context manager so the HTTP response is always closed.
        with urllib.request.urlopen(url) as res:
            if res.status != 200:
                raise Exception(f"Failed to connect to Whisper-JAX container. Status: {res.status}")

    def connect(self):
        """
        Connect to the Whisper-JAX container and ensure it's ready.
        """
        self._connect()
        logging.info("Successfully connected to Whisper-JAX container")

    def run_command(self, command: str):
        """
        Run a Python command inside the container.

        :param command: Python source code executed with ``python -c``.
        :return: Whatever ``DockerContainer.exec`` returns for the call.
        """
        # Pass an argv list instead of a single shell-style string: the
        # source is then delivered verbatim as one argument, so quotes inside
        # ``command`` can no longer break out of the ``-c`` argument.
        return self.exec(["python", "-c", command])

    def transcribe_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = False):
        """
        Transcribe an audio file using Whisper-JAX.

        :param file_path: Path of the audio file as seen from inside the
            container (the script that reads it runs there).
        :param task: Task name passed to the pipeline call.
        :param return_timestamps: Forwarded to the pipeline call.
        :return: The result of :meth:`run_command`; the script prints the
            pipeline result to its standard output.
        """
        import textwrap

        # Values are interpolated with ``!r`` so they become valid Python
        # string literals even when they contain quotes or backslashes.
        script = textwrap.dedent(
            f"""\
            import soundfile as sf
            from whisper_jax import FlaxWhisperPipline
            import jax.numpy as jnp

            pipeline = FlaxWhisperPipline({self.model_name!r}, dtype=jnp.bfloat16, batch_size=16)
            audio, sr = sf.read({file_path!r})
            result = pipeline({{"array": audio, "sampling_rate": sr}}, task={task!r}, return_timestamps={bool(return_timestamps)!r})
            print(result)
            """
        )
        return self.run_command(script)

    def transcribe_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = False):
        """
        Transcribe a YouTube video using Whisper-JAX.

        The generated script downloads the audio track with ``youtube_dl``
        into a temporary ``.wav`` file inside the container and feeds it to
        the pipeline.

        :param youtube_url: URL of the video to download.
        :param task: Task name passed to the pipeline call.
        :param return_timestamps: Forwarded to the pipeline call.
        :return: The result of :meth:`run_command`; the script prints the
            pipeline result to its standard output.
        """
        import textwrap

        # Values are interpolated with ``!r`` so they become valid Python
        # string literals even when they contain quotes or backslashes.
        script = textwrap.dedent(
            f"""\
            import tempfile
            import youtube_dl
            import soundfile as sf
            from whisper_jax import FlaxWhisperPipline
            import jax.numpy as jnp

            def download_youtube_audio(youtube_url, output_file):
                ydl_opts = {{
                    'format': 'bestaudio/best',
                    'postprocessors': [{{
                        'key': 'FFmpegExtractAudio',
                        'preferredcodec': 'wav',
                        'preferredquality': '192',
                    }}],
                    'outtmpl': output_file,
                }}
                with youtube_dl.YoutubeDL(ydl_opts) as ydl:
                    ydl.download([youtube_url])

            pipeline = FlaxWhisperPipline({self.model_name!r}, dtype=jnp.bfloat16, batch_size=16)

            with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file:
                download_youtube_audio({youtube_url!r}, temp_file.name)
                audio, sr = sf.read(temp_file.name)
                result = pipeline({{"array": audio, "sampling_rate": sr}}, task={task!r}, return_timestamps={bool(return_timestamps)!r})
                print(result)
            """
        )
        return self.run_command(script)

    def start(self):
        """
        Start the Whisper-JAX container.

        :return: ``self``, so the call can be chained or used in a ``with`` block.
        """
        super().start()
        logging.info(f"Whisper-JAX container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}")
        return self
1+

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ packages = [
4141
{ include = "testcontainers", from = "modules/test_module_import"},
4242
{ include = "testcontainers", from = "modules/google" },
4343
{ include = "testcontainers", from = "modules/influxdb" },
44+
{ include = "testcontainers", from = "modules/jax" },
4445
{ include = "testcontainers", from = "modules/k3s" },
4546
{ include = "testcontainers", from = "modules/kafka" },
4647
{ include = "testcontainers", from = "modules/keycloak" },

0 commit comments

Comments
 (0)