1515 print(result.eou_type, result.trailing_duration)
1616"""
1717
18- from dataclasses import dataclass
18+ from dataclasses import dataclass , field
1919from enum import StrEnum
2020from typing import Union
2121
@@ -35,6 +35,15 @@ class EoUType(StrEnum):
3535 NOISE = "noise" # significant trailing region with high energy relative to speech
3636
3737
@dataclass
class TokenSegment:
    """One contiguous run of a single non-blank token in the forced alignment.

    Times are derived from CTC frame indices multiplied by the frame
    duration, so all time fields are in seconds.
    """

    token: str
    start: float  # segment start, seconds
    end: float  # segment end, seconds
    duration: float  # end - start, seconds
    confidence: float  # mean per-frame score over the segment
45+
46+
3847@dataclass
3948class EoUClassification :
4049 eou_type : EoUType
@@ -44,6 +53,10 @@ class EoUClassification:
4453 trail_rms_ratio : float
4554 last_token_duration : float
4655 last_token_confidence : float
56+ last_token : str
57+ last_token_gap : float # blank gap (seconds) between last and second-to-last speech token
58+ last_two_phoneme_avg_confidence : float # average confidence of last two alphanumeric tokens
59+ token_segments : list [TokenSegment ] = field (default_factory = list )
4760
4861
4962class EoUClassifier :
@@ -61,6 +74,7 @@ def __init__(self, model_name: str = "facebook/wav2vec2-base-960h", sr: int = SR
6174 self .model .eval ()
6275 self .blank_id = self .processor .tokenizer .pad_token_id
6376 self .vocab = self .processor .tokenizer .get_vocab ()
77+ self .id_to_token = {v : k for k , v in self .vocab .items ()}
6478
6579 def _text_to_tokens (self , text : str ) -> list [int ]:
6680 text = text .upper ().strip ()
@@ -93,32 +107,90 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
93107 aligned_ids = fa_ids [0 ].numpy ()
94108 scores = torch .exp (fa_scores [0 ]).numpy ()
95109
96- speech_end_frame = 0
97- last_token_start = 0
98- last_token_id = - 1
99- for i in range (len (aligned_ids ) - 1 , - 1 , - 1 ):
100- if aligned_ids [i ] != self .blank_id :
101- if last_token_id == - 1 :
102- speech_end_frame = i + 1
103- last_token_id = int (aligned_ids [i ])
104- last_token_start = i
105- elif int (aligned_ids [i ]) != last_token_id :
106- break
107- else :
108- last_token_start = i
109-
110- if last_token_id == - 1 :
110+ segments : list [TokenSegment ] = []
111+ cur_id = - 1
112+ seg_start = 0
113+ for i , aid in enumerate (aligned_ids ):
114+ tid = int (aid )
115+ if tid == self .blank_id :
116+ if cur_id != - 1 :
117+ seg_scores = scores [seg_start :i ]
118+ segments .append (
119+ TokenSegment (
120+ token = self .id_to_token .get (cur_id , f"<id:{ cur_id } >" ),
121+ start = seg_start * frame_duration ,
122+ end = i * frame_duration ,
123+ duration = (i - seg_start ) * frame_duration ,
124+ confidence = float (seg_scores .mean ()),
125+ )
126+ )
127+ cur_id = - 1
128+ elif tid != cur_id :
129+ if cur_id != - 1 :
130+ seg_scores = scores [seg_start :i ]
131+ segments .append (
132+ TokenSegment (
133+ token = self .id_to_token .get (cur_id , f"<id:{ cur_id } >" ),
134+ start = seg_start * frame_duration ,
135+ end = i * frame_duration ,
136+ duration = (i - seg_start ) * frame_duration ,
137+ confidence = float (seg_scores .mean ()),
138+ )
139+ )
140+ cur_id = tid
141+ seg_start = i
142+ # else: same non-blank token continues
143+ if cur_id != - 1 :
144+ seg_scores = scores [seg_start : len (aligned_ids )]
145+ segments .append (
146+ TokenSegment (
147+ token = self .id_to_token .get (cur_id , f"<id:{ cur_id } >" ),
148+ start = seg_start * frame_duration ,
149+ end = len (aligned_ids ) * frame_duration ,
150+ duration = (len (aligned_ids ) - seg_start ) * frame_duration ,
151+ confidence = float (seg_scores .mean ()),
152+ )
153+ )
154+
155+ if not segments :
111156 return {
112157 "speech_end" : 0.0 ,
113158 "last_token_duration" : 0.0 ,
114159 "last_token_confidence" : 0.0 ,
160+ "last_token" : "" ,
161+ "last_token_gap" : 0.0 ,
162+ "last_two_phoneme_avg_confidence" : 0.0 ,
163+ "token_segments" : [],
115164 }
116165
117- last_seg_scores = scores [last_token_start :speech_end_frame ]
166+ last = segments [- 1 ]
167+
168+ # Skip trailing punctuation/non-letter tokens for cutoff analysis,
169+ # since they don't correspond to real speech sounds and get
170+ # unreliably short durations from forced alignment.
171+ last_speech = last
172+ for seg in reversed (segments ):
173+ if seg .token .isalnum ():
174+ last_speech = seg
175+ break
176+
177+ last_idx = segments .index (last_speech )
178+ if last_idx > 0 :
179+ last_token_gap = last_speech .start - segments [last_idx - 1 ].end
180+ else :
181+ last_token_gap = last_speech .start
182+
183+ last_two_alnum = [s for s in segments if s .token .isalnum ()][- 2 :]
184+ last_two_avg = float (np .mean ([s .confidence for s in last_two_alnum ]))
185+
118186 return {
119- "speech_end" : speech_end_frame * frame_duration ,
120- "last_token_duration" : (speech_end_frame - last_token_start ) * frame_duration ,
121- "last_token_confidence" : float (last_seg_scores .mean ()),
187+ "speech_end" : last .end , # + 0.05, # add 50ms of tolerance
188+ "last_token_duration" : last_speech .duration ,
189+ "last_token_confidence" : last_speech .confidence ,
190+ "last_token" : last_speech .token ,
191+ "last_token_gap" : last_token_gap ,
192+ "last_two_phoneme_avg_confidence" : last_two_avg ,
193+ "token_segments" : segments ,
122194 }
123195
124196 def classify (
@@ -146,8 +218,8 @@ def classify(
146218
147219 speech_end = info ["speech_end" ]
148220 trailing = audio_dur - speech_end
149-
150- trail_start = int (speech_end * self .sr )
221+ last_letter_pad = 0.15
222+ trail_start = int (( speech_end + last_letter_pad ) * self .sr )
151223 trailing_audio = samples [trail_start :]
152224 if len (trailing_audio ) > 0 :
153225 rms_trail = np .sqrt (np .mean (trailing_audio ** 2 ))
@@ -158,15 +230,26 @@ def classify(
158230
159231 last_dur = info ["last_token_duration" ]
160232 last_conf = info ["last_token_confidence" ]
161-
162- if trailing < 0.06 and last_dur < 0.025 and last_conf < 0.1 :
163- # speech ends abruptly, with a very short last token and low confidence --> cutoff
233+ if last_conf < 0.01 :
234+ last_conf = info ["last_two_phoneme_avg_confidence" ]
235+ last_tok = info ["last_token" ]
236+ last_gap = info ["last_token_gap" ]
237+ last_two_avg = info ["last_two_phoneme_avg_confidence" ]
238+ token_segments = info ["token_segments" ]
239+
240+ # if trailing < 0.06 and (last_dur < 0.025 and last_conf < 0.15):
241+ # if trailing < 0.06 and (last_gap < 0.1 and last_conf < 0.15): short trail and not due to gap
242+ conf_threshold = 0.07
243+ # short tail with low confidence and not due to gap (which could indicate noise) --> cutoff
244+ if trailing < 0.1 and last_conf < conf_threshold and not last_gap > 0.4 :
245+ # speech ends abruptly, with a very short last token or low confidence --> cutoff
164246 eou_type = EoUType .CUTOFF
165- elif trailing > 0.3 and trail_rms_ratio > 0.5 :
247+ # long noisy tail OR odd gap --> noisy
248+ elif (trailing > 0.25 and trail_rms_ratio > 0.3 ) or (last_gap > 0.4 and last_conf < conf_threshold ):
166249 # significant trailing region with high energy relative to speech --> noise
167250 eou_type = EoUType .NOISE
168- elif trailing > 1.0 and trail_rms_ratio < 0.10 :
169- # very long trailing region with near-zero energy --> silence
251+ # very long trailing region with near-zero energy --> silence
252+ elif trailing > 1.0 : # and trail_rms_ratio < 0.10:
170253 eou_type = EoUType .SILENCE
171254 else :
172255 # everything else (moderate trailing, natural energy decay) --> good
@@ -180,6 +263,10 @@ def classify(
180263 trail_rms_ratio = trail_rms_ratio ,
181264 last_token_duration = last_dur ,
182265 last_token_confidence = last_conf ,
266+ last_token = last_tok ,
267+ last_token_gap = last_gap ,
268+ last_two_phoneme_avg_confidence = last_two_avg ,
269+ token_segments = token_segments ,
183270 )
184271
185272
@@ -200,3 +287,9 @@ def classify(
200287 print (f"trail_rms_ratio: { result .trail_rms_ratio :.4f} " )
201288 print (f"last_token_dur: { result .last_token_duration :.3f} s" )
202289 print (f"last_token_conf: { result .last_token_confidence :.3f} " )
290+ print (f"last_token_gap: { result .last_token_gap :.3f} s" )
291+ print (f"last_2_ph_avg_conf: { result .last_two_phoneme_avg_confidence :.3f} " )
292+ print (f"last_token: { result .last_token !r} " )
293+ print (f"\n Token segments ({ len (result .token_segments )} ):" )
294+ for seg in result .token_segments :
295+ print (f" { seg .token !r:<6} { seg .start :.3f} -{ seg .end :.3f} s dur={ seg .duration :.3f} s conf={ seg .confidence :.3f} " )
0 commit comments