 from __future__ import annotations
 
 import numpy as np
+import sounddevice as sd
 import torch
 from scipy.signal import resample_poly
 from transformers.models.whisper import WhisperConfig
@@ -56,6 +57,7 @@ def __init__(
 
         self.feature_extractor = get_feature_extractor(hf_model_id)
         self.tokenizer = get_tokenizer(hf_model_id)
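+        # Special token ids (e.g., timestamp markers); assumed here to delimit
+        # segments so that stream() can decide when text is safe to flush.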
+        self.clip_segment_tokens = set(self.tokenizer.all_special_ids)
 
     def predict(self, *args, **kwargs):
         # See transcribe.
@@ -82,23 +84,10 @@ def transcribe(
         -------
         List of audio arrays, chunked into N arrays of model_chunk_seconds seconds.
         """
-        if isinstance(audio, str):
-            import audio2numpy as a2n  # import here, as this requires ffmpeg to be installed on host machine
-
-            audio, audio_sample_rate = a2n.audio_from_file(audio)
-        else:
-            assert audio_sample_rate is not None
-        assert isinstance(audio, np.ndarray)
-        assert isinstance(audio_sample_rate, int)
-        with torch.no_grad():
-            trans = " ".join(
-                self._transcribe_single_chunk(x)
-                for x in chunk_and_resample_audio(audio, audio_sample_rate)
-            )
-
-        return trans
+        tokens = self.transcribe_tokens(audio, audio_sample_rate)
+        return self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
 
-    def _transcribe_single_chunk(self, audio: np.ndarray) -> str:
+    def _transcribe_single_chunk(self, audio: np.ndarray) -> list[int]:
         """
         Transcribe an audio chunk to text.
 
@@ -110,8 +99,7 @@ def _transcribe_single_chunk(self, audio: np.ndarray) -> str:
             The maximum length of this audio must be self.max_audio_samples.
 
         Returns:
-
-        - transcribed texts
+            list of token ids
         """
         # feature
         input_features = self.feature_extractor(
@@ -213,8 +201,123 @@ def _transcribe_single_chunk(self, audio: np.ndarray) -> str:
             # update position_ids
             position_ids += 1
 
-        # Exclude start / end tokens
-        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        return output_ids[0].tolist()
+
+    def stream(self, device: int = 2, audio_chunk_size_seconds: int = 5) -> None:
+        """
+        Stream audio from the given audio device and transcribe it in real time.
+
+        Parameters:
+            device:
+                Audio device index (see sounddevice.query_devices()).
+            audio_chunk_size_seconds:
+                Number of seconds to record between each transcription attempt.
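+
+        Example (hypothetical usage; pick a device index from sounddevice.query_devices()):
+            app.stream(device=1, audio_chunk_size_seconds=5)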
+        """
+        tokens: list[int] = []
+
+        def callback(audio: np.ndarray, frames, time, status):
+            nonlocal tokens
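+            # sounddevice delivers a (frames, channels) float32 block, so drop
+            # the mono channel axis before transcribing this chunk.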
+            curr_tokens = self.transcribe_tokens(audio.squeeze(-1), SAMPLE_RATE)
+            tokens.extend(curr_tokens)
+
+            if not curr_tokens:
+                # This audio was empty, so it's safe to decode previous tokens.
+                print(
+                    self.tokenizer.decode(tokens, skip_special_tokens=True),
+                    end="",
+                    flush=True,
+                )
+                tokens = []
+            else:
+                split_start = 0
+                decode_splits = []
+                token_idx = 0
+                # Every time 2 "clip segment tokens" (timestamp tokens)
+                # appear in sequence, we're safe to decode the previous tokens.
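+                # Illustration: in [..., tok, <|ts|>, <|ts|>, tok, ...] the run
+                # of two timestamp tokens means everything up to and including
+                # the first <|ts|> is a finished segment that can be flushed.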
+                while token_idx < len(tokens):
+                    if tokens[token_idx] in self.clip_segment_tokens:
+                        next_non_clip_idx = token_idx + 1
+                        while (
+                            next_non_clip_idx < len(tokens)
+                            and tokens[next_non_clip_idx] in self.clip_segment_tokens
+                        ):
+                            next_non_clip_idx += 1
+
+                        if next_non_clip_idx >= token_idx + 2:
+                            split_end = token_idx + 1
+                            if split_end > split_start:
+                                decode_splits.append((split_start, split_end))
+                            split_start = next_non_clip_idx
+
+                        token_idx = next_non_clip_idx + 1
+                    else:
+                        token_idx += 1
+
+                for split in decode_splits:
+                    print(
+                        self.tokenizer.decode(
+                            tokens[split[0] : split[1]], skip_special_tokens=True
+                        ),
+                        end="",
+                        flush=True,
+                    )
+                if split_start != 0:
+                    tokens = tokens[split_start:]
+
+        print("Listening...")
+        print("Text can take up to 20 seconds before printing.")
+        with sd.InputStream(
+            device=device,
+            channels=1,
+            blocksize=audio_chunk_size_seconds * SAMPLE_RATE,
+            callback=callback,
+            samplerate=SAMPLE_RATE,
+        ):
+            while True:
+                response = input("Press ctrl+c or q/Q to quit.\n")
+                if response in ("q", "Q"):
+                    break
+
+    def transcribe_tokens(
+        self, audio: np.ndarray | str, audio_sample_rate: int | None = None
+    ) -> list[int]:
+        """
+        Transcribe the provided audio into a list of token ids.
+
+        Parameters
+        ----------
+        audio: numpy array | str
+            Path to audio file if a string.
+            Raw audio array of shape (# of samples) if a numpy array.
+
+        audio_sample_rate: int | None
+            The sample rate of the provided audio, in samples / second.
+            This must be provided if audio is a numpy array.
+            If audio is a file path, it is ignored; the sample rate is read
+            from the audio file instead.
+
+        Returns
+        -------
+        transcribed tokens
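+
+        Example (hypothetical; `samples` is a 1-D float array at 16 kHz):
+            tokens = app.transcribe_tokens(samples, 16000)
+            text = app.tokenizer.decode(tokens, skip_special_tokens=True)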
+        """
+        if isinstance(audio, str):
+            import audio2numpy as a2n  # import here, as this requires ffmpeg to be installed on host machine
+
+            audio, audio_sample_rate = a2n.audio_from_file(audio)
+        if isinstance(audio, np.ndarray) and audio.ndim == 2:
+            # Audio is multi-channel (e.g., stereo); collapse it to one channel.
+            audio = audio.mean(-1)
+
+        assert audio_sample_rate is not None
+        assert isinstance(audio, np.ndarray)
+
+        out_chunked_tokens: list[list[int]] = [
+            self._transcribe_single_chunk(x)
+            for x in chunk_and_resample_audio(audio, audio_sample_rate)
+        ]
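+        # Flatten the per-chunk token sequences into a single list.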
+        out_tokens: list[int] = []
+        for chunk_tokens in out_chunked_tokens:
+            out_tokens.extend(chunk_tokens)
+        return out_tokens
 
 
 def chunk_and_resample_audio(