import logging
import urllib.request
from typing import Optional
from urllib.error import URLError

from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_container_is_ready

class JAXWhisperDiarizationContainer(DockerContainer):
    """
    JAX-Whisper-Diarization container for fast speech recognition, transcription, and speaker diarization.

    Example:

        .. doctest::

            >>> logging.basicConfig(level=logging.INFO)
            >>> # You need to provide your Hugging Face token to use the pyannote.audio models
            >>> hf_token = "your_huggingface_token_here"
            >>> with JAXWhisperDiarizationContainer(hf_token=hf_token) as whisper_diarization:
            ...     whisper_diarization.connect()
            ...     # Example: Transcribe and diarize an audio file
            ...     result = whisper_diarization.transcribe_and_diarize_file("/path/to/audio/file.wav")
            ...     print(f"Transcription and Diarization: {result}")
            ...     # Example: Transcribe and diarize a YouTube video
            ...     result = whisper_diarization.transcribe_and_diarize_youtube("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
            ...     print(f"YouTube Transcription and Diarization: {result}")
    """

    def __init__(self, model_name: str = "openai/whisper-large-v2", hf_token: Optional[str] = None, **kwargs):
        super().__init__("nvcr.io/nvidia/jax:23.08-py3", **kwargs)
        self.model_name = model_name
        self.hf_token = hf_token
        self.with_exposed_ports(8888)  # Expose Jupyter notebook port
        self.with_env("NVIDIA_VISIBLE_DEVICES", "all")
        self.with_env("CUDA_VISIBLE_DEVICES", "all")
        self.with_kwargs(runtime="nvidia")  # Use NVIDIA runtime for GPU support

        # Install the required dependencies at startup, then launch Jupyter so the
        # readiness check has an HTTP endpoint to poll.
        self.with_command(
            "sh -c '"
            "pip install --no-cache-dir git+https://github.com/sanchit-gandhi/whisper-jax.git && "
            "pip install --no-cache-dir numpy soundfile youtube_dl transformers datasets pyannote.audio && "
            "python -m pip install --upgrade --no-cache-dir jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && "
            'jupyter notebook --ip 0.0.0.0 --port 8888 --allow-root --NotebookApp.token="" --NotebookApp.password=""'
            "'"
        )

    @wait_container_is_ready(URLError)
    def _connect(self):
        """
        Poll the Jupyter endpoint until the container responds.
        """
        url = f"http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}"
        res = urllib.request.urlopen(url)
        if res.status != 200:
            raise Exception(f"Failed to connect to JAX-Whisper-Diarization container. Status: {res.status}")

    def connect(self):
        """
        Connect to the JAX-Whisper-Diarization container and ensure it's ready.
        """
        self._connect()
        logging.info("Successfully connected to JAX-Whisper-Diarization container")

    def run_command(self, command: str):
        """
        Run a Python command inside the container.
        """
        # Pass an argv list so the embedded script is not re-tokenized by a shell,
        # which would otherwise break on the quotes the generated code contains.
        exec_result = self.exec(["python", "-c", command])
        return exec_result

    def transcribe_and_diarize_file(self, file_path: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True):
        """
        Transcribe and diarize an audio file using Whisper-JAX and pyannote.
        """
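        # The embedded script below prints a list of {"speaker", "text", "timestamp"}
        # dictionaries to stdout; run_command returns the exec result so the caller
        # can inspect the exit code and captured output.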
        command = f"""
import soundfile as sf
import torch
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
from pyannote.audio import Pipeline
import numpy as np

def align(transcription, segments, group_by_speaker=True):
    transcription_split = transcription.split("\\n")
    transcript = []
    for chunk in transcription_split:
        start_end, text = chunk[1:].split("] ")
        start, end = start_end.split("->")
        start, end = float(start), float(end)
        transcript.append({{"timestamp": (start, end), "text": text}})

    new_segments = []
    prev_segment = segments[0]
    for i in range(1, len(segments)):
        cur_segment = segments[i]
        if cur_segment["label"] != prev_segment["label"]:
            new_segments.append({{
                "segment": {{"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}},
                "speaker": prev_segment["label"]
            }})
            prev_segment = segments[i]
    new_segments.append({{
        "segment": {{"start": prev_segment["segment"]["start"], "end": segments[-1]["segment"]["end"]}},
        "speaker": prev_segment["label"]
    }})

    end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
    segmented_preds = []

    for segment in new_segments:
        end_time = segment["segment"]["end"]
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        if group_by_speaker:
            segmented_preds.append({{
                "speaker": segment["speaker"],
                "text": " ".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1])
            }})
        else:
            for i in range(upto_idx + 1):
                segmented_preds.append({{"speaker": segment["speaker"], **transcript[i]}})

        transcript = transcript[upto_idx + 1 :]
        end_timestamps = end_timestamps[upto_idx + 1 :]

    return segmented_preds

pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16)
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="{self.hf_token}")

audio, sr = sf.read("{file_path}")
inputs = {{"array": audio, "sampling_rate": sr}}

# Transcribe
result = pipeline(inputs, task="{task}", return_timestamps={return_timestamps})

# Diarize
diarization = diarization_pipeline({{"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sr}})
segments = diarization.for_json()["content"]

# Align transcription and diarization
aligned_result = align(result["text"], segments, group_by_speaker={group_by_speaker})
print(aligned_result)
"""
        return self.run_command(command)

    def transcribe_and_diarize_youtube(self, youtube_url: str, task: str = "transcribe", return_timestamps: bool = True, group_by_speaker: bool = True):
        """
        Transcribe and diarize a YouTube video using Whisper-JAX and pyannote.
        """
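        # The audio track is first downloaded to a temporary WAV file inside the
        # container via youtube_dl, then transcribed and diarized like a local file.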
        command = f"""
import tempfile
import youtube_dl
import soundfile as sf
import torch
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
from pyannote.audio import Pipeline
import numpy as np

def download_youtube_audio(youtube_url, output_file):
    ydl_opts = {{
        'format': 'bestaudio/best',
        'postprocessors': [{{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
            'preferredquality': '192',
        }}],
        'outtmpl': output_file,
    }}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        ydl.download([youtube_url])

def align(transcription, segments, group_by_speaker=True):
    transcription_split = transcription.split("\\n")
    transcript = []
    for chunk in transcription_split:
        start_end, text = chunk[1:].split("] ")
        start, end = start_end.split("->")
        start, end = float(start), float(end)
        transcript.append({{"timestamp": (start, end), "text": text}})

    new_segments = []
    prev_segment = segments[0]
    for i in range(1, len(segments)):
        cur_segment = segments[i]
        if cur_segment["label"] != prev_segment["label"]:
            new_segments.append({{
                "segment": {{"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]}},
                "speaker": prev_segment["label"]
            }})
            prev_segment = segments[i]
    new_segments.append({{
        "segment": {{"start": prev_segment["segment"]["start"], "end": segments[-1]["segment"]["end"]}},
        "speaker": prev_segment["label"]
    }})

    end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
    segmented_preds = []

    for segment in new_segments:
        end_time = segment["segment"]["end"]
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        if group_by_speaker:
            segmented_preds.append({{
                "speaker": segment["speaker"],
                "text": " ".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1])
            }})
        else:
            for i in range(upto_idx + 1):
                segmented_preds.append({{"speaker": segment["speaker"], **transcript[i]}})

        transcript = transcript[upto_idx + 1 :]
        end_timestamps = end_timestamps[upto_idx + 1 :]

    return segmented_preds

pipeline = FlaxWhisperPipline("{self.model_name}", dtype=jnp.bfloat16, batch_size=16)
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="{self.hf_token}")

with tempfile.NamedTemporaryFile(suffix=".wav") as temp_file:
    download_youtube_audio("{youtube_url}", temp_file.name)
    audio, sr = sf.read(temp_file.name)
    inputs = {{"array": audio, "sampling_rate": sr}}

    # Transcribe
    result = pipeline(inputs, task="{task}", return_timestamps={return_timestamps})

    # Diarize
    diarization = diarization_pipeline({{"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sr}})
    segments = diarization.for_json()["content"]

    # Align transcription and diarization
    aligned_result = align(result["text"], segments, group_by_speaker={group_by_speaker})
    print(aligned_result)
"""
        return self.run_command(command)

    def start(self):
        """
        Start the JAX-Whisper-Diarization container.
        """
        super().start()
        logging.info(f"JAX-Whisper-Diarization container started. Jupyter URL: http://{self.get_container_host_ip()}:{self.get_exposed_port(8888)}")
        return self
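

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only: it assumes Docker with the NVIDIA
# runtime is available on the host, that the HF_TOKEN environment variable
# holds a valid Hugging Face token, and that "/data/sample.wav" is a
# placeholder path for an audio file reachable inside the container.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    logging.basicConfig(level=logging.INFO)

    with JAXWhisperDiarizationContainer(hf_token=os.environ.get("HF_TOKEN")) as whisper_diarization:
        whisper_diarization.connect()
        result = whisper_diarization.transcribe_and_diarize_file("/data/sample.wav")
        print(f"Transcription and Diarization: {result}")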