@@ -59,7 +59,7 @@ def subpipe_maker(framerate_ratio):
     return subpipe_maker(scale_factor)


-def _make_auditok_detector(sample_rate, frame_rate):
+def _make_auditok_detector(sample_rate, frame_rate, non_speech_label):
     try:
         from auditok import \
             BufferAudioSource, ADSFactory, AudioEnergyValidator, StreamTokenizer
@@ -76,31 +76,37 @@ def _make_auditok_detector(sample_rate, frame_rate):
     bytes_per_frame = 2
     frames_per_window = frame_rate // sample_rate
     validator = AudioEnergyValidator(
-        sample_width=bytes_per_frame, energy_threshold=50)
+        sample_width=bytes_per_frame, energy_threshold=50
+    )
     tokenizer = StreamTokenizer(
-        validator=validator, min_length=0.2 * sample_rate,
-        max_length=int(5 * sample_rate),
-        max_continuous_silence=0.25 * sample_rate)
+        validator=validator,
+        min_length=0.2 * sample_rate,
+        max_length=int(5 * sample_rate),
+        max_continuous_silence=0.25 * sample_rate
+    )

     def _detect(asegment):
-        asource = BufferAudioSource(data_buffer=asegment,
-                                    sampling_rate=frame_rate,
-                                    sample_width=bytes_per_frame,
-                                    channels=1)
+        asource = BufferAudioSource(
+            data_buffer=asegment,
+            sampling_rate=frame_rate,
+            sample_width=bytes_per_frame,
+            channels=1
+        )
         ads = ADSFactory.ads(audio_source=asource, block_dur=1./sample_rate)
         ads.open()
         tokens = tokenizer.tokenize(ads)
-        length = (len(asegment)//bytes_per_frame
-                  + frames_per_window - 1)//frames_per_window
-        media_bstring = np.zeros(length + 1, dtype=int)
+        length = (
+            len(asegment)//bytes_per_frame + frames_per_window - 1
+        ) // frames_per_window
+        media_bstring = np.zeros(length + 1)
         for token in tokens:
-            media_bstring[token[1]] += 1
-            media_bstring[token[2]+1] -= 1
-        return (np.cumsum(media_bstring)[:-1] > 0).astype(float)
+            media_bstring[token[1]] = 1.
+            media_bstring[token[2] + 1] = non_speech_label - 1.
+        return np.clip(np.cumsum(media_bstring)[:-1], 0., 1.)
     return _detect


-def _make_webrtcvad_detector(sample_rate, frame_rate):
+def _make_webrtcvad_detector(sample_rate, frame_rate, non_speech_label):
     import webrtcvad
     vad = webrtcvad.Vad()
     vad.set_mode(3)  # set non-speech pruning aggressiveness from 0 to 3
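
A note on the rewritten auditok scoring above: each token marks its first frame
with 1. and the frame just past its last frame with (non_speech_label - 1.), so
a cumulative sum followed by a clip yields a per-frame score of 1. inside
tokens. A minimal standalone sketch of the trick (the tokens_to_speech_frames
helper is hypothetical; shown with the default label of 0.):

    import numpy as np

    def tokens_to_speech_frames(tokens, num_frames, non_speech_label=0.):
        # mark token boundaries, then integrate and clip to [0, 1]
        marks = np.zeros(num_frames + 1)
        for _, start, end in tokens:
            marks[start] = 1.
            marks[end + 1] = non_speech_label - 1.
        return np.clip(np.cumsum(marks)[:-1], 0., 1.)

    # frames 2-4 and 7-8 are speech; with label 0. all other frames score 0.
    print(tokens_to_speech_frames([(None, 2, 4), (None, 7, 8)], 10))
    # -> [0. 0. 1. 1. 1. 0. 0. 1. 1. 0.]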
@@ -123,7 +129,7 @@ def _detect(asegment):
                 is_speech = False
                 failures += 1
             # webrtcvad has low recall on mode 3, so treat non-speech as "not sure"
-            media_bstring.append(1. if is_speech else 0.5)
+            media_bstring.append(1. if is_speech else non_speech_label)
         return np.array(media_bstring)

     return _detect
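
For context, the detector this hunk touches feeds fixed-size 16-bit mono PCM
windows to webrtcvad and emits one score per window. Below is a hypothetical,
self-contained approximation, not the project's exact code; note that
webrtcvad only accepts 8/16/32/48 kHz audio in 10/20/30 ms frames, so
frame_rate and sample_rate must be chosen to produce a valid window size:

    import numpy as np
    import webrtcvad

    def make_detector(sample_rate, frame_rate, non_speech_label):
        vad = webrtcvad.Vad()
        vad.set_mode(3)  # most aggressive non-speech pruning
        bytes_per_frame = 2  # 16-bit samples
        frames_per_window = frame_rate // sample_rate

        def _detect(asegment):
            media_bstring = []
            num_frames = len(asegment) // bytes_per_frame
            for start in range(0, num_frames, frames_per_window):
                stop = min(start + frames_per_window, num_frames)
                window = asegment[start * bytes_per_frame:stop * bytes_per_frame]
                try:
                    is_speech = vad.is_speech(window, sample_rate=frame_rate)
                except Exception:
                    is_speech = False  # malformed window: count as non-speech
                # low recall on mode 3, so score non-speech as "not sure"
                media_bstring.append(1. if is_speech else non_speech_label)
            return np.array(media_bstring)

        return _detect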
@@ -141,20 +147,23 @@ def num_frames(self):
         return self.end_frame_ - self.start_frame_

     def fit_boundaries(self, speech_frames):
-        nz = np.nonzero(speech_frames)[0]
+        nz = np.nonzero(speech_frames > 0.5)[0]
         if len(nz) > 0:
             self.start_frame_ = np.min(nz)
             self.end_frame_ = np.max(nz)
         return self


-class VideoSpeechTransformer(TransformerMixin, ComputeSpeechFrameBoundariesMixin):
-    def __init__(self, vad, sample_rate, frame_rate, start_seconds=0,
-                 ffmpeg_path=None, ref_stream=None, vlc_mode=False, gui_mode=False):
+class VideoSpeechTransformer(TransformerMixin):
+    def __init__(
+            self, vad, sample_rate, frame_rate, non_speech_label, start_seconds=0,
+            ffmpeg_path=None, ref_stream=None, vlc_mode=False, gui_mode=False
+    ):
         super(VideoSpeechTransformer, self).__init__()
         self.vad = vad
         self.sample_rate = sample_rate
         self.frame_rate = frame_rate
+        self._non_speech_label = non_speech_label
         self.start_seconds = start_seconds
         self.ffmpeg_path = ffmpeg_path
         self.ref_stream = ref_stream
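
Why fit_boundaries now thresholds at > 0.5: once non-speech frames can carry a
nonzero label, np.nonzero alone would flag every frame as speech. A quick
illustration:

    import numpy as np

    speech_frames = np.array([0.5, 0.5, 1., 1., 0.5, 1., 0.5])
    print(np.nonzero(speech_frames)[0])        # [0 1 2 3 4 5 6]: all frames
    print(np.nonzero(speech_frames > 0.5)[0])  # [2 3 5]: true speech only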
@@ -197,7 +206,6 @@ def try_fit_using_embedded_subs(self, fname):
         # use longest set of embedded subs
         subs_to_use = embedded_subs[int(np.argmax(embedded_subs_times))]
         self.video_speech_results_ = subs_to_use.subtitle_speech_results_
-        self.fit_boundaries(self.video_speech_results_)

     def fit(self, fname, *_):
         if 'subs' in self.vad and (self.ref_stream is None or self.ref_stream.startswith('0:s:')):
@@ -216,9 +224,9 @@ def fit(self, fname, *_):
             logger.warning(e)
             total_duration = None
         if 'webrtc' in self.vad:
-            detector = _make_webrtcvad_detector(self.sample_rate, self.frame_rate)
+            detector = _make_webrtcvad_detector(self.sample_rate, self.frame_rate, self._non_speech_label)
         elif 'auditok' in self.vad:
-            detector = _make_auditok_detector(self.sample_rate, self.frame_rate)
+            detector = _make_auditok_detector(self.sample_rate, self.frame_rate, self._non_speech_label)
         else:
             raise ValueError('unknown vad: %s' % self.vad)
         media_bstring = []
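
How the new argument threads through in practice; a hypothetical construction
(parameter values are illustrative, not necessarily the project defaults):

    transformer = VideoSpeechTransformer(
        vad='webrtc',
        sample_rate=100,       # VAD decisions per second of audio
        frame_rate=48000,      # PCM sample rate handed to the detector
        non_speech_label=0.5,  # score uncertain frames halfway to speech
    )
    transformer.fit('movie.mkv')
    speech = transformer.transform()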
@@ -284,7 +292,6 @@ def redirect_stderr(enter_result=None):
                 'Unable to detect speech. Perhaps try specifying a different stream / track, or a different vad.'
             )
         self.video_speech_results_ = np.concatenate(media_bstring)
-        self.fit_boundaries(self.video_speech_results_)
         return self

     def transform(self, *_):
@@ -300,6 +307,7 @@ def transform(self, *_):
 }


+# TODO: need way better metadata detector
 def _is_metadata(content, is_beginning_or_end):
     content = content.strip()
     if len(content) == 0:
@@ -348,9 +356,10 @@ def transform(self, *_):
         return self.subtitle_speech_results_


-class DeserializeSpeechTransformer(TransformerMixin, ComputeSpeechFrameBoundariesMixin):
-    def __init__(self):
+class DeserializeSpeechTransformer(TransformerMixin):
+    def __init__(self, non_speech_label):
         super(DeserializeSpeechTransformer, self).__init__()
+        self._non_speech_label = non_speech_label
         self.deserialized_speech_results_ = None

     def fit(self, fname, *_):
@@ -361,8 +370,8 @@ def fit(self, fname, *_):
         else:
             raise ValueError('could not find "speech" array in '
                              'serialized file; only contains: %s' % speech.files)
+        speech[speech < 1.] = self._non_speech_label
         self.deserialized_speech_results_ = speech
-        self.fit_boundaries(self.deserialized_speech_results_)
         return self

     def transform(self, *_):
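
The deserializer now rewrites every value below 1. in the stored speech array
with the configured label; a hypothetical round trip:

    import numpy as np

    np.savez('speech.npz', speech=np.array([0., 1., 0.5, 1.]))
    speech = np.load('speech.npz')['speech']
    speech[speech < 1.] = 0.5  # non_speech_label
    print(speech)  # [0.5 1.  0.5 1. ]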