1+ # Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
115"""
216Classify the end-of-utterance (EoU) audio as: good (natural ending), cutoff (abrupt
317ending), silence (long trailing region that is quiet), or noise (significant trailing
1529 print(result.eou_type, result.trailing_duration)
1630"""
1731
32+ import math
1833from dataclasses import dataclass , field
1934from enum import StrEnum
2035from typing import Union
@@ -48,6 +63,11 @@ class EoUType(StrEnum):
4863 SILENCE = "silence" # long trailing region with near-zero energy
4964 NOISE = "noise" # significant trailing region with high energy relative to speech
5065
@classmethod
def error_types(cls) -> tuple["EoUType", ...]:
    """Return every member that denotes a problematic ending (all members except GOOD)."""
    problematic = [member for member in cls if member != cls.GOOD]
    return tuple(problematic)
70+
5171
5272@dataclass
5373class TokenSegment :
@@ -89,44 +109,57 @@ def __init__(self, model_name: str = "facebook/wav2vec2-base-960h", sr: int = SR
89109 self .blank_id = self .processor .tokenizer .pad_token_id
90110 self .vocab = self .processor .tokenizer .get_vocab ()
91111 self .id_to_token = {v : k for k , v in self .vocab .items ()}
112+ self .frame_duration = math .prod (self .model .config .conv_stride ) / self .sr
92113
93114 def _text_to_tokens (self , text : str ) -> list [int ]:
115+ # Wav2Vec2 uses uppercase characters; normalize to match its vocabulary
94116 text = text .upper ().strip ()
95117 tokens = []
96118 for i , word in enumerate (text .split ()):
119+ # "|" is the word-boundary token in Wav2Vec2's CTC vocabulary
97120 if i > 0 :
98121 tokens .append (self .vocab ["|" ])
99122 for char in word :
123+ # Skip characters not in vocab (punctuation, accents, etc.)
100124 if char in self .vocab :
101125 tokens .append (self .vocab [char ])
102126 return tokens
103127
104- def _find_speech_end (self , audio : np .ndarray , text : str ) -> dict :
128+ def _forced_align (self , audio : np .ndarray , text : str ) -> dict :
105129 """Run forced alignment and return speech boundary info."""
130+ # Tokenize audio into Wav2Vec2 input features
106131 input_values = self .processor (audio , return_tensors = "pt" , sampling_rate = self .sr ).input_values
107132
133+ # Forward pass through the CTC model to get per-frame logits
108134 with torch .no_grad ():
109135 logits = self .model (input_values ).logits [0 ]
110136
111137 log_probs = torch .log_softmax (logits , dim = - 1 )
112- n_frames = len (logits )
113- frame_duration = len (audio ) / n_frames / self .sr
138+ frame_duration = self .frame_duration
114139
140+ # Convert target text to CTC token IDs and run torchaudio forced alignment
115141 target_tokens = self ._text_to_tokens (text )
116142 fa_ids , fa_scores = taf .forced_align (
117143 log_probs .unsqueeze (0 ),
118144 torch .tensor ([target_tokens ]),
119145 blank = self .blank_id ,
120146 )
147+ # fa_ids: per-frame aligned token IDs (blank or target token)
148+ # fa_scores: per-frame log-prob scores; convert to probabilities for confidence
121149 aligned_ids = fa_ids [0 ].numpy ()
122150 scores = torch .exp (fa_scores [0 ]).numpy ()
123151
152+ # Walk through the frame-level alignment and merge consecutive
153+ # frames of the same token into TokenSegment objects.
154+ # Transitions: blank-to-token (start new segment), token-to-blank (end segment),
155+ # token-to-different token (end old, start new).
124156 segments : list [TokenSegment ] = []
125- cur_id = - 1
157+ cur_id = - 1 # indicates that no segment is open
126158 seg_start = 0
127159 for i , aid in enumerate (aligned_ids ):
128160 tid = int (aid )
129161 if tid == self .blank_id :
162+ # Blank frame: close the current segment if one is open
130163 if cur_id != - 1 :
131164 seg_scores = scores [seg_start :i ]
132165 segments .append (
@@ -140,6 +173,7 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
140173 )
141174 cur_id = - 1
142175 elif tid != cur_id :
176+ # New non-blank token: close previous segment (if any) and start a new one
143177 if cur_id != - 1 :
144178 seg_scores = scores [seg_start :i ]
145179 segments .append (
@@ -153,7 +187,8 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
153187 )
154188 cur_id = tid
155189 seg_start = i
156- # else: same non-blank token continues
190+ # else: same non-blank token continues — keep extending the segment
191+ # Flush the last open segment if the alignment ends on a non-blank token
157192 if cur_id != - 1 :
158193 seg_scores = scores [seg_start : len (aligned_ids )]
159194 segments .append (
@@ -166,6 +201,7 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
166201 )
167202 )
168203
204+ # No tokens were aligned — return zeroed-out defaults
169205 if not segments :
170206 return {
171207 "speech_end" : 0.0 ,
@@ -188,17 +224,22 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
188224 last_speech = seg
189225 break
190226
227+ # Measure the blank gap between the last speech token and its predecessor.
228+ # A large gap can indicate noise or misalignment before the final sound.
191229 last_idx = segments .index (last_speech )
192230 if last_idx > 0 :
193231 last_token_gap = last_speech .start - segments [last_idx - 1 ].end
194232 else :
233+ # First (and only) token — gap is measured from audio start
195234 last_token_gap = last_speech .start
196235
236+ # Average confidence of the last two alphanumeric tokens;
237+ # used as a fallback when the single last-token confidence is near zero.
197238 last_two_alnum = [s for s in segments if s .token .isalnum ()][- 2 :]
198239 last_two_avg = float (np .mean ([s .confidence for s in last_two_alnum ]))
199240
200241 return {
201- "speech_end" : last .end , # + 0.05, # add 50ms of tolerance
242+ "speech_end" : last .end ,
202243 "last_token_duration" : last_speech .duration ,
203244 "last_token_confidence" : last_speech .confidence ,
204245 "last_token" : last_speech .token ,
@@ -213,7 +254,7 @@ def classify(
213254 text : str ,
214255 ) -> EoUClassification :
215256 """
216- Classify the end-of-utterance quality of a TTS audio sample .
257+ Classify the end-of-utterance quality of utterance audio.
217258
218259 Args:
219260 audio: Path to a WAV file, or a numpy array of audio samples at self.sr.
@@ -222,19 +263,24 @@ def classify(
222263 Returns:
223264 EoUClassification with the predicted eou_type and supporting features.
224265 """
266+ # Accept either a file path or a pre-loaded numpy array
225267 if isinstance (audio , np .ndarray ):
226268 samples = audio
227269 else :
228270 samples , _ = librosa .load (audio , sr = self .sr )
229271
230272 audio_dur = len (samples ) / self .sr
231- info = self ._find_speech_end (samples , text )
273+ # Run forced alignment and collect information about speech segments
274+ info = self ._forced_align (samples , text )
232275
233276 speech_end = info ["speech_end" ]
234277 trailing = audio_dur - speech_end
235278 last_letter_pad = 0.15 if _ends_with_sibilant (text ) else 0.1
236279 trail_start = int ((speech_end + last_letter_pad ) * self .sr )
237280 trailing_audio = samples [trail_start :]
281+
282+ # Compute RMS energy ratio between the trailing region and the full
283+ # utterance — a high ratio means the tail is loud
238284 if len (trailing_audio ) > 0 :
239285 rms_trail = np .sqrt (np .mean (trailing_audio ** 2 ))
240286 rms_full = np .sqrt (np .mean (samples ** 2 ))
@@ -251,22 +297,19 @@ def classify(
251297 last_two_avg = info ["last_two_phoneme_avg_confidence" ]
252298 token_segments = info ["token_segments" ]
253299
254- # if trailing < 0.06 and (last_dur < 0.025 and last_conf < 0.15):
255- # if trailing < 0.06 and (last_gap < 0.1 and last_conf < 0.15): short trail and not due to gap
300+ # --- Decision tree for EoU classification ---
256301 conf_threshold = 0.07
257- # short tail with low confidence and not due to gap (which could indicate noise) --> cutoff
302+ # Short tail with low confidence and not due to gap (which could indicate noise) --> cutoff
258303 if trailing < 0.1 and last_conf < conf_threshold and not last_gap > 0.4 :
259- # speech ends abruptly, with a very short last token or low confidence --> cutoff
260304 eou_type = EoUType .CUTOFF
261- # long noisy tail OR odd gap --> noisy
305+ # Long noisy tail OR a gap between the last two segments and low confidence --> noisy
262306 elif (trailing > 0.15 and trail_rms_ratio > 0.4 ) or (last_gap > 0.4 and last_conf < 0.15 ):
263- # significant trailing region with high energy relative to speech --> noise
264307 eou_type = EoUType .NOISE
265- # very long trailing region with near-zero energy --> silence
266- elif trailing > 1.0 : # and trail_rms_ratio < 0.10:
308+ # Long tail without much energy (or it would have been captured by the previous condition) --> silence
309+ elif trailing > 1.0 :
267310 eou_type = EoUType .SILENCE
268311 else :
269- # everything else (moderate trailing, natural energy decay) --> good
312+ # everything else --> good
270313 eou_type = EoUType .GOOD
271314
272315 return EoUClassification (
0 commit comments