Commit 8cdeb11

v0.36.0
See https://github.com/quic/ai-hub-models/releases/v0.36.0 for changelog. Signed-off-by: QAIHM Team <[email protected]>
1 parent 7385ead commit 8cdeb11


55 files changed (+1923, -2236 lines)

README.md

Lines changed: 3 additions & 3 deletions
@@ -289,10 +289,10 @@ and many more.
 | | |
 | **Speech Recognition**
 | [HuggingFace-WavLM-Base-Plus](https://aihub.qualcomm.com/models/huggingface_wavlm_base_plus) | [qai_hub_models.models.huggingface_wavlm_base_plus](qai_hub_models/models/huggingface_wavlm_base_plus/README.md) |
-| [Whisper-Base-En](https://aihub.qualcomm.com/models/whisper_base_en) | [qai_hub_models.models.whisper_base_en](qai_hub_models/models/whisper_base_en/README.md) |
+| [Whisper-Base](https://aihub.qualcomm.com/models/whisper_base) | [qai_hub_models.models.whisper_base](qai_hub_models/models/whisper_base/README.md) |
 | [Whisper-Large-V3-Turbo](https://aihub.qualcomm.com/models/whisper_large_v3_turbo) | [qai_hub_models.models.whisper_large_v3_turbo](qai_hub_models/models/whisper_large_v3_turbo/README.md) |
-| [Whisper-Small-En](https://aihub.qualcomm.com/models/whisper_small_en) | [qai_hub_models.models.whisper_small_en](qai_hub_models/models/whisper_small_en/README.md) |
-| [Whisper-Tiny-En](https://aihub.qualcomm.com/models/whisper_tiny_en) | [qai_hub_models.models.whisper_tiny_en](qai_hub_models/models/whisper_tiny_en/README.md) |
+| [Whisper-Small](https://aihub.qualcomm.com/models/whisper_small) | [qai_hub_models.models.whisper_small](qai_hub_models/models/whisper_small/README.md) |
+| [Whisper-Tiny](https://aihub.qualcomm.com/models/whisper_tiny) | [qai_hub_models.models.whisper_tiny](qai_hub_models/models/whisper_tiny/README.md) |
 | | |
 | **Audio Classification**
 | [YamNet](https://aihub.qualcomm.com/models/yamnet) | [qai_hub_models.models.yamnet](qai_hub_models/models/yamnet/README.md) |

qai_hub_models/_version.py

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # ---------------------------------------------------------------------
 
-__version__ = "0.35.0"
+__version__ = "0.36.0"

qai_hub_models/global_requirements.txt

Lines changed: 3 additions & 2 deletions
@@ -45,8 +45,8 @@ mypy==1.13.0
 numba==0.60.0
 numpy<2
 object-detection-metrics==0.4.post1
-onnx>=1.16.1
-onnxruntime>=1.19
+onnx>=1.16.1,<1.20
+onnxruntime>=1.19,<1.23
 onnxsim<=0.4.36;python_version<'3.12'
 onnxsim-prebuilt==0.4.36.post1;python_version>='3.12'
 opencv-python>4,<5
@@ -82,6 +82,7 @@ seaborn==0.11.0
 segment-anything==1.0
 sentencepiece==0.2.0
 shapely==2.0.3
+sounddevice==0.5.2
 soundfile==0.13.1
 stringcase==1.2.0
 supervision==0.25.1

qai_hub_models/models/_shared/hf_whisper/app.py

Lines changed: 123 additions & 20 deletions
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import numpy as np
+import sounddevice as sd
 import torch
 from scipy.signal import resample_poly
 from transformers.models.whisper import WhisperConfig
@@ -56,6 +57,7 @@ def __init__(
 
         self.feature_extractor = get_feature_extractor(hf_model_id)
         self.tokenizer = get_tokenizer(hf_model_id)
+        self.clip_segment_tokens = set(self.tokenizer.all_special_ids)
 
     def predict(self, *args, **kwargs):
         # See transcribe.
@@ -82,23 +84,10 @@ def transcribe(
         -------
         List of audio arrays, chunked into N arrays of model_chunk_seconds seconds.
         """
-        if isinstance(audio, str):
-            import audio2numpy as a2n  # import here, as this requires ffmpeg to be installed on host machine
-
-            audio, audio_sample_rate = a2n.audio_from_file(audio)
-        else:
-            assert audio_sample_rate is not None
-        assert isinstance(audio, np.ndarray)
-        assert isinstance(audio_sample_rate, int)
-        with torch.no_grad():
-            trans = " ".join(
-                self._transcribe_single_chunk(x)
-                for x in chunk_and_resample_audio(audio, audio_sample_rate)
-            )
-
-        return trans
+        tokens = self.transcribe_tokens(audio, audio_sample_rate)
+        return self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
 
-    def _transcribe_single_chunk(self, audio: np.ndarray) -> str:
+    def _transcribe_single_chunk(self, audio: np.ndarray) -> list[int]:
         """
         Transcribe an audio chunk to text.
 
@@ -110,8 +99,7 @@ def _transcribe_single_chunk(self, audio: np.ndarray) -> str:
         The maximum length of this audio must be self.max_audio_samples.
 
         Returns:
-
-        - transcribed texts
+            list of token ids
         """
         # feature
         input_features = self.feature_extractor(
@@ -213,8 +201,123 @@ def _transcribe_single_chunk(self, audio: np.ndarray) -> str:
             # update position_ids
             position_ids += 1
 
-        # Exclude start / end tokens
-        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        return output_ids[0].tolist()
+
+    def stream(self, device=2, audio_chunk_size_seconds: int = 5) -> None:
+        """
+        Stream audio from the given audio device and transcribe in real time.
+
+        Parameters:
+            device:
+                Audio device (see sounddevice.query_devices())
+            audio_chunk_size_seconds:
+                Number of seconds to record between each transcription attempt.
+        """
+        tokens: list[int] = []
+
+        def callback(audio: np.ndarray, frames, time, status):
+            nonlocal tokens
+            curr_tokens = self.transcribe_tokens(audio.squeeze(-1), SAMPLE_RATE)
+            tokens.extend(curr_tokens)
+
+            if not curr_tokens:
+                # This audio was empty, so it's safe to decode previous tokens.
+                print(
+                    self.tokenizer.decode(tokens, skip_special_tokens=True),
+                    end="",
+                    flush=True,
+                )
+                tokens = []
+            else:
+                split_start = 0
+                decode_splits = []
+                token_idx = 0
+                # Every time 2 "clip segment tokens" (timestamp tokens)
+                # appear in sequence, we're safe to decode the previous tokens.
+                while token_idx < len(tokens):
+                    if tokens[token_idx] in self.clip_segment_tokens:
+                        next_non_clip_idx = token_idx + 1
+                        while (
+                            next_non_clip_idx < len(tokens)
+                            and tokens[next_non_clip_idx] in self.clip_segment_tokens
+                        ):
+                            next_non_clip_idx = next_non_clip_idx + 1
+
+                        if next_non_clip_idx >= token_idx + 2:
+                            split_end = token_idx + 1
+                            if max(split_end - split_start, 0) > 0:
+                                decode_splits.append((split_start, split_end))
+                            split_start = next_non_clip_idx
+
+                        token_idx = next_non_clip_idx + 1
+                    else:
+                        token_idx = token_idx + 1
+
+                for split in decode_splits:
+                    print(
+                        self.tokenizer.decode(
+                            tokens[split[0] : split[1]], skip_special_tokens=True
+                        ),
+                        end="",
+                        flush=True,
+                    )
+                if split_start != 0:
+                    tokens = tokens[split_start:]
+
+        print("Listening...")
+        print("Text can take up to 20 seconds before printing.")
+        with sd.InputStream(
+            device=device,
+            channels=1,
+            blocksize=audio_chunk_size_seconds * SAMPLE_RATE,
+            callback=callback,
+            samplerate=SAMPLE_RATE,
+        ):
+            while True:
+                response = input("Press ctrl+c or q/Q to quit.\n")
+                if response in ("q", "Q"):
+                    break
+
+    def transcribe_tokens(
+        self, audio: np.ndarray | str, audio_sample_rate: int | None = None
+    ) -> list[int]:
+        """
+        Transcribe the provided audio to text.
+
+        Parameters
+        ----------
+        audio: numpy array | str
+            Path to audio file if a string.
+            Raw audio array of shape (# of samples) if a numpy array.
+
+        audio_sample_rate: int | None
+            The sample rate of the provided audio, in samples / second.
+            If audio is a numpy array, this must be provided.
+            If audio is a file and audio_sample_rate is None, this is ignored and the sample rate will be derived from the audio file.
+
+        Returns
+        -------
+        transcribed tokens
+        """
+        if isinstance(audio, str):
+            import audio2numpy as a2n  # import here, as this requires ffmpeg to be installed on host machine
+
+            audio, audio_sample_rate = a2n.audio_from_file(audio)
+        if isinstance(audio, np.ndarray) and audio.ndim == 2:
+            # Audio is multi-channel (e.g., stereo); collapse to single.
+            audio = audio.mean(-1)
+
+        assert audio_sample_rate is not None
+        assert isinstance(audio, np.ndarray)
+
+        out_chunked_tokens: list[list[int]] = [
+            self._transcribe_single_chunk(x)
+            for x in chunk_and_resample_audio(audio, audio_sample_rate)
+        ]
+        out_tokens: list[int] = []
+        for chunk_tokens in out_chunked_tokens:
+            out_tokens.extend(chunk_tokens)
+        return out_tokens
 
 
 def chunk_and_resample_audio(
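
For reference, the updated HfWhisperApp API introduced above can be exercised roughly as follows. This is a minimal usage sketch, not part of the commit: it assumes a concrete Whisper model package such as qai_hub_models.models.whisper_base exposing a Model class with from_pretrained() and get_hf_whisper_version() (the same pattern the demo below uses via model_cls), and the audio file path is hypothetical.

# Minimal usage sketch of the updated HfWhisperApp API (assumptions noted above).
from qai_hub_models.models._shared.hf_whisper.app import HfWhisperApp
from qai_hub_models.models.whisper_base import Model  # assumed concrete HfWhisper subclass

model = Model.from_pretrained()
app = HfWhisperApp(model.encoder, model.decoder, Model.get_hf_whisper_version())

# File-based transcription: transcribe() now decodes the token ids returned by
# transcribe_tokens() and strips special tokens before returning text.
print(app.transcribe("speech.wav"))  # hypothetical audio file path

# Live transcription from a sounddevice input (added in this commit);
# pass the device index reported by sounddevice.query_devices().
# app.stream(device=1, audio_chunk_size_seconds=10)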

qai_hub_models/models/_shared/hf_whisper/demo.py

Lines changed: 26 additions & 10 deletions
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # ---------------------------------------------------------------------
 
-
 import numpy as np
 
 from qai_hub_models.models._shared.hf_whisper.app import HfWhisperApp
@@ -30,22 +29,39 @@ def load_demo_audio() -> tuple[np.ndarray, int]:
 def hf_whisper_demo(model_cls: type[HfWhisper], is_test: bool = False) -> None:
     parser = get_model_cli_parser(model_cls)
     parser.add_argument(
-        "--audio_file",
+        "--audio-file",
         type=str,
         default=None,
         help="Audio file path or URL",
     )
+    parser.add_argument(
+        "--stream-audio-device",
+        type=int,
+        default=None,
+        help="Audio device (number) to stream from.",
+    )
+    parser.add_argument(
+        "--stream-audio-chunk-size",
+        type=int,
+        default=10,
+        help="For audio streaming, the number of seconds to record between each transcription attempt. A minimum of around 10 seconds is recommended for best accuracy.",
+    )
     args = parser.parse_args([] if is_test else None)
+    if (args.stream_audio_device is not None) and (args.audio_file is not None):
+        raise ValueError("Cannot set both audio-file and stream-audio-device")
 
     model = model_cls.from_pretrained()
     app = HfWhisperApp(model.encoder, model.decoder, model_cls.get_hf_whisper_version())
 
-    # Load default audio if file not provided
-    audio = args.audio_file
-    audio_sample_rate = None
-    if not audio:
-        audio, audio_sample_rate = load_demo_audio()
+    if args.stream_audio_device:
+        app.stream(args.stream_audio_device, args.stream_audio_chunk_size)
+    else:
+        # Load default audio if file not provided
+        audio = args.audio_file
+        audio_sample_rate = None
+        if not audio:
+            audio, audio_sample_rate = load_demo_audio()
 
-    # Perform transcription
-    transcription = app.transcribe(audio, audio_sample_rate)
-    print("Transcription:", transcription)
+        # Perform transcription
+        transcription = app.transcribe(audio, audio_sample_rate)
+        print("Transcription:", transcription)

qai_hub_models/models/_shared/hf_whisper/test_utils.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def run_test_transcribe(
     model = WhisperForConditionalGeneration.from_pretrained(hf_whisper_version)
     predicted_ids = model.generate(mel_input)
     tokenizer = WhisperTokenizer.from_pretrained(hf_whisper_version)
-    text_orig = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
+    text_orig = tokenizer.decode(predicted_ids[0], skip_special_tokens=True).strip()
 
     # Perform transcription
     transcription = app.transcribe(audio, sample_rate)
