
Commit e6448e4

add diarization
1 parent 03dd3b8 commit e6448e4

File tree

1 file changed: +249 -0 lines
  • modules/jax/testcontainers/whisper-diarization

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
import logging
import urllib.request
from typing import Optional
from urllib.error import URLError

from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_container_is_ready


class JAXWhisperDiarizationContainer(DockerContainer):
    """
    JAX-Whisper-Diarization container for fast speech recognition, transcription, and speaker diarization.

    Example:

        .. doctest::

            >>> logging.basicConfig(level=logging.INFO)

            >>> # You need to provide your Hugging Face token to use the pyannote.audio models
            >>> hf_token = "your_huggingface_token_here"

            >>> with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization:
            ...     whisper_diarization.connect()
            ...
            ...     # Example: Transcribe and diarize an audio file
            ...     result = whisper_diarization.transcribe_and_diarize_file("/path/to/audio/file.wav")
            ...     print(f"Transcription and Diarization: {result}")
            ...
            ...     # Example: Transcribe and diarize a YouTube video
            ...     result = whisper_diarization.transcribe_and_diarize_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
            ...     print(f"YouTube Transcription and Diarization: {result}")
    """

    def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Optional[str] = None, **kwargs):
        super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs)
        self.model_name = model_name
        self.hf_token = hf_token
        self.with_exposed_ports(8888)  # Expose the Jupyter notebook port
        self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
        self.with_env("CUDA_VISIBLE_DEVICES", "all")
        self.with_kwargs(runtime="nvidia")  # Use the NVIDIA runtime for GPU support

        # Install the required dependencies, then start a Jupyter server to keep the container alive
        self.with_command(
            "sh -c '"
            "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
            "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && "
            "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
            "jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token='' --NotebookApp.password=''"
            "'"
        )

    @wait_container_is_ready(URLError)
    def _connect(self):
        url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
        res = urllib.request.urlopen(url)
        if res.status != 200:
            raise Exception(f"Failed to connect to JAX-Whisper-Diarization container. Status: {res.status}")

    def connect(self):
        """
        Connect to the JAX-Whisper-Diarization container and ensure it is ready.
        """
        self._connect()
        logging.info("Successfully connected to JAX-Whisper-Diarization container")

    def run_command(self, command: str):
        """
        Run a Python command inside the container.
        """
        exec_result = self.exec(f"python -c '{command}'")
        return exec_result

    def transcribe_and_diarize_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True):
        """
        Transcribe and diarize an audio file using Whisper-JAX and pyannote.audio.
        """
        command = f"""
import soundfile as sf
import torch
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
from pyannote.audio import Pipeline
import numpy as np

def align(transcription, segments, group_by_speaker=True):
    # Parse the timestamped transcription, one "[start -> end] text" chunk per line
    transcription_split = transcription.split("\\n")
    transcript = []
    for chunk in transcription_split:
        start_end, text = chunk[1:].split("] ")
        start, end = start_end.split("->")
        start, end = float(start), float(end)
        transcript.append({{"timestamp": (start, end), "text": text}})

    # Merge consecutive diarization segments that belong to the same speaker
    new_segments = []
    prev_segment = segments[0]
    for i in range(1, len(segments)):
        cur_segment = segments[i]
        if cur_segment["label"] != prev_segment["label"]:
            new_segments.append({{
                "segment": {{"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}},
                "speaker": prev_segment["label"]
            }})
            prev_segment = segments[i]
    new_segments.append({{
        "segment": {{"start": prev_segment["segment"]["start"], "end": segments[-1]["segment"]["end"]}},
        "speaker": prev_segment["label"]
    }})

    # Assign each transcription chunk to the speaker segment whose end time is closest
    end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
    segmented_preds = []

    for segment in new_segments:
        end_time = segment["segment"]["end"]
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        if group_by_speaker:
            segmented_preds.append({{
                "speaker": segment["speaker"],
                "text": " ".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1])
            }})
        else:
            for i in range(upto_idx + 1):
                segmented_preds.append({{"speaker": segment["speaker"], **transcript[i]}})

        transcript = transcript[upto_idx + 1:]
        end_timestamps = end_timestamps[upto_idx + 1:]

    return segmented_preds

pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16)
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="{self.hf_token}")

audio, sr = sf.read("{file_path}")
inputs = {{"array": audio, "sampling_rate": sr}}

# Transcribe
result = pipeline(inputs, task="{task}", return_timestamps={return_timestamps})

# Diarize
diarization = diarization_pipeline({{"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sr}})
segments = diarization.for_json()["content"]

# Align transcription and diarization
aligned_result = align(result["text"], segments, group_by_speaker={group_by_speaker})
print(aligned_result)
"""
        return self.run_command(command)

    def transcribe_and_diarize_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True):
        """
        Transcribe and diarize a YouTube video using Whisper-JAX and pyannote.audio.
        """
        command = f"""
import tempfile
import youtube_dl
import soundfile as sf
import torch
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
from pyannote.audio import Pipeline
import numpy as np

def download_youtube_audio(youtube_url, output_file):
    # Double quotes throughout keep the outer python -c '...' wrapper intact
    ydl_opts = {{
        "format": "bestaudio/best",
        "postprocessors": [{{
            "key": "FFmpegExtractAudio",
            "preferredcodec": "wav",
            "preferredquality": "192",
        }}],
        "outtmpl": output_file,
    }}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        ydl.download([youtube_url])

def align(transcription, segments, group_by_speaker=True):
    # Parse the timestamped transcription, one "[start -> end] text" chunk per line
    transcription_split = transcription.split("\\n")
    transcript = []
    for chunk in transcription_split:
        start_end, text = chunk[1:].split("] ")
        start, end = start_end.split("->")
        start, end = float(start), float(end)
        transcript.append({{"timestamp": (start, end), "text": text}})

    # Merge consecutive diarization segments that belong to the same speaker
    new_segments = []
    prev_segment = segments[0]
    for i in range(1, len(segments)):
        cur_segment = segments[i]
        if cur_segment["label"] != prev_segment["label"]:
            new_segments.append({{
                "segment": {{"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}},
                "speaker": prev_segment["label"]
            }})
            prev_segment = segments[i]
    new_segments.append({{
        "segment": {{"start": prev_segment["segment"]["start"], "end": segments[-1]["segment"]["end"]}},
        "speaker": prev_segment["label"]
    }})

    # Assign each transcription chunk to the speaker segment whose end time is closest
    end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
    segmented_preds = []

    for segment in new_segments:
        end_time = segment["segment"]["end"]
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        if group_by_speaker:
            segmented_preds.append({{
                "speaker": segment["speaker"],
                "text": " ".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1])
            }})
        else:
            for i in range(upto_idx + 1):
                segmented_preds.append({{"speaker": segment["speaker"], **transcript[i]}})

        transcript = transcript[upto_idx + 1:]
        end_timestamps = end_timestamps[upto_idx + 1:]

    return segmented_preds

pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16)
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="{self.hf_token}")

with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file:
    download_youtube_audio("{youtube_url}", temp_file.name)
    audio, sr = sf.read(temp_file.name)
    inputs = {{"array": audio, "sampling_rate": sr}}

    # Transcribe
    result = pipeline(inputs, task="{task}", return_timestamps={return_timestamps})

    # Diarize
    diarization = diarization_pipeline({{"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sr}})
    segments = diarization.for_json()["content"]

    # Align transcription and diarization
    aligned_result = align(result["text"], segments, group_by_speaker={group_by_speaker})
    print(aligned_result)
"""
        return self.run_command(command)

    def start(self):
        """
        Start the JAX-Whisper-Diarization container.
        """
        super().start()
        logging.info(f"JAX-Whisper-Diarization container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}")
        return self
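
For reference, a minimal usage sketch of the new container from a pytest-style test, assuming a GPU-enabled Docker host with the NVIDIA runtime, a valid Hugging Face token for the pyannote.audio models, and that the module above is importable; the import path, token value, and audio path below are placeholders and are not part of this commit:

import logging

# Hypothetical import path; adjust to wherever this module lands in your project.
from whisper_diarization import JAXWhisperDiarizationContainer

def test_transcribe_and_diarize_file():
    logging.basicConfig(level=logging.INFO)
    hf_token = "your_huggingface_token_here"  # placeholder; required by pyannote.audio
    with JAXWhisperDiarizationContainer(hf_token=hf_token) as container:
        container.connect()
        # Runs Whisper-JAX plus pyannote inside the container and returns the exec result
        result = container.transcribe_and_diarize_file("/path/to/audio/file.wav")
        print(result)

Per the align() helper above, the output printed inside the container is a list of {"speaker", "text", "timestamp"} entries, grouped per speaker turn when group_by_speaker is True.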
