Skip to content

Commit 03dd3b8

Browse files
committed
add whisper-jax
1 parent 2541e57 commit 03dd3b8

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import logging
2+
import tempfile
3+
import time
4+
from typing import Optional
5+
6+
from core.testcontainers.core.container import DockerContainer
7+
from core.testcontainers.core.waiting_utils import wait_container_is_ready
8+
from urllib.error import URLError
9+
10+
class WhisperJAXContainer(DockerContainer):
    """
    Whisper-JAX container for fast speech recognition and transcription.

    The container is based on the ``nvcr.io/nvidia/jax:23.08-py3`` image. On
    start-up it installs whisper-jax and its helper packages with ``pip`` and
    then launches a Jupyter notebook server on port 8888, which is used as the
    readiness probe by :meth:`connect`.

    Transcription scripts are executed inside the container, so any file path
    given to :meth:`transcribe_file` must exist inside the container.

    Example:

        .. doctest::

            >>> from testcontainers.whisper_jax import WhisperJAXContainer

            >>> with WhisperJAXContainer("openai/whisper-large-v2") as whisper:
            ...     # Connect to the container
            ...     whisper.connect()
            ...
            ...     # Transcribe an audio file
            ...     result = whisper.transcribe_file("path/to/audio/file.wav")
            ...     print(result['text'])
            ...
            ...     # Transcribe a YouTube video
            ...     result = whisper.transcribe_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
            ...     print(result['text'])
    """

    def __init__(self, model_name: str = "openai/whisper-large-v2", **kwargs):
        """
        Configure the container image, GPU runtime and start-up command.

        :param model_name: Whisper checkpoint name handed to the whisper-jax
            pipeline by the ``transcribe_*`` methods.
        :param kwargs: forwarded unchanged to ``DockerContainer.__init__``.
        """
        super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs)
        self.model_name = model_name
        self.with_exposed_ports(8888)  # Expose Jupyter notebook port
        self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
        # NOTE: CUDA_VISIBLE_DEVICES is intentionally not set. It only accepts
        # device indices/UUIDs; the value "all" is an invalid identifier and
        # makes CUDA see no devices at all.
        self.with_kwargs(runtime="nvidia")  # Use NVIDIA runtime for GPU support

        # Install required dependencies, then start Jupyter as the long-running
        # foreground process that keeps the container alive.
        self.with_command(
            "sh -c '"
            "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
            "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets && "
            "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
            "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''"
            "'"
        )

    @wait_container_is_ready(URLError)
    def _connect(self):
        """
        Probe the Jupyter endpoint once and fail unless it answers HTTP 200.

        NOTE(review): presumably the ``wait_container_is_ready`` decorator
        retries this call while ``URLError`` is raised -- confirm against the
        helper's implementation.

        :raises Exception: if the endpoint answers with a non-200 status.
        """
        # Imported locally: the module only imports ``urllib.error`` at file
        # level, so ``urllib.request`` is not otherwise guaranteed to be bound.
        import urllib.request

        url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
        # Close the response explicitly so repeated probes do not leak sockets.
        with urllib.request.urlopen(url) as res:
            status = res.status
        if status != 200:
            raise Exception(f"Failed to connect to Whisper-JAX container. Status: {status}")

    def connect(self):
        """
        Connect to the Whisper-JAX container and ensure it's ready.
        """
        self._connect()
        logging.info("Successfully connected to Whisper-JAX container")

    def run_command(self, command: str):
        """
        Run a Python command inside the container.

        The source is passed to ``python -c`` as a single argv element rather
        than being spliced into a quoted shell string, so it may freely contain
        single quotes, double quotes and newlines.

        :param command: Python source code to execute. Common leading
            indentation is removed so indented multi-line literals work.
        :return: whatever ``self.exec`` returns for the executed command.
        """
        import textwrap

        script = textwrap.dedent(command).strip()
        # NOTE(review): assumes ``DockerContainer.exec`` accepts an argv list
        # (as docker-py's ``exec_run`` does) -- confirm for the vendored core.
        return self.exec(["python", "-c", script])

    def transcribe_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = False):
        """
        Transcribe an audio file using Whisper-JAX.

        :param file_path: path of the audio file; it is read by code running
            inside the container.
        :param task: task name forwarded to the whisper-jax pipeline.
        :param return_timestamps: forwarded to the whisper-jax pipeline.
        :return: the result of :meth:`run_command`; the pipeline result is
            printed by the script executed in the container.
        """
        # Values are embedded with ``!r`` so quotes/backslashes in them become
        # valid Python literals instead of breaking (or injecting into) the
        # generated script. ``FlaxWhisperPipline`` is the name the script
        # imports from whisper_jax; it is kept exactly as in the original.
        script = f"""
            import soundfile as sf
            from whisper_jax import FlaxWhisperPipline
            import jax.numpy as jnp

            pipeline = FlaxWhisperPipline({self.model_name!r}, dtype=jnp.bfloat16, batch_size=16)
            audio, sr = sf.read({file_path!r})
            result = pipeline({{"array": audio, "sampling_rate": sr}}, task={task!r}, return_timestamps={return_timestamps!r})
            print(result)
        """
        return self.run_command(script)

    def transcribe_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = False):
        """
        Transcribe a YouTube video using Whisper-JAX.

        The generated script downloads the audio with ``youtube_dl`` into a
        temporary ``.wav`` file inside the container and feeds it to the
        whisper-jax pipeline.

        :param youtube_url: URL of the video to download and transcribe.
        :param task: task name forwarded to the whisper-jax pipeline.
        :param return_timestamps: forwarded to the whisper-jax pipeline.
        :return: the result of :meth:`run_command`; the pipeline result is
            printed by the script executed in the container.
        """
        # See transcribe_file for why interpolated values use ``!r``.
        script = f"""
            import tempfile
            import youtube_dl
            import soundfile as sf
            from whisper_jax import FlaxWhisperPipline
            import jax.numpy as jnp

            def download_youtube_audio(youtube_url, output_file):
                ydl_opts = {{
                    'format': 'bestaudio/best',
                    'postprocessors': [{{
                        'key': 'FFmpegExtractAudio',
                        'preferredcodec': 'wav',
                        'preferredquality': '192',
                    }}],
                    'outtmpl': output_file,
                }}
                with youtube_dl.YoutubeDL(ydl_opts) as ydl:
                    ydl.download([youtube_url])

            pipeline = FlaxWhisperPipline({self.model_name!r}, dtype=jnp.bfloat16, batch_size=16)

            with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file:
                download_youtube_audio({youtube_url!r}, temp_file.name)
                audio, sr = sf.read(temp_file.name)
                result = pipeline({{"array": audio, "sampling_rate": sr}}, task={task!r}, return_timestamps={return_timestamps!r})
                print(result)
        """
        return self.run_command(script)

    def start(self):
        """
        Start the Whisper-JAX container.

        :return: ``self``, so the call can be chained or used in a ``with``
            statement.
        """
        super().start()
        logging.info(f"Whisper-JAX container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}")
        return self

0 commit comments

Comments
 (0)