Skip to content

Commit 9122af1

Browse files
committed
Add copyright headers
Signed-off-by: Fejgin, Roy <rfejgin@nvidia.com>
1 parent acf0c73 commit 9122af1

File tree

2 files changed

+61
-18
lines changed

2 files changed

+61
-18
lines changed

nemo/collections/tts/metrics/eou_classifier.py

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
"""
216
Classify the end-of-utterance (EoU) audio as: good (natural ending), cutoff (abrupt
317
ending), silence (long trailing region that is quiet), or noise (significant trailing
@@ -15,6 +29,7 @@
1529
print(result.eou_type, result.trailing_duration)
1630
"""
1731

32+
import math
1833
from dataclasses import dataclass, field
1934
from enum import StrEnum
2035
from typing import Union
@@ -48,6 +63,11 @@ class EoUType(StrEnum):
4863
SILENCE = "silence" # long trailing region with near-zero energy
4964
NOISE = "noise" # significant trailing region with high energy relative to speech
5065

66+
@classmethod
67+
def error_types(cls) -> tuple["EoUType", ...]:
68+
"""All types that represent an error (everything except GOOD)."""
69+
return tuple(t for t in cls if t != cls.GOOD)
70+
5171

5272
@dataclass
5373
class TokenSegment:
@@ -89,44 +109,57 @@ def __init__(self, model_name: str = "facebook/wav2vec2-base-960h", sr: int = SR
89109
self.blank_id = self.processor.tokenizer.pad_token_id
90110
self.vocab = self.processor.tokenizer.get_vocab()
91111
self.id_to_token = {v: k for k, v in self.vocab.items()}
112+
self.frame_duration = math.prod(self.model.config.conv_stride) / self.sr
92113

93114
def _text_to_tokens(self, text: str) -> list[int]:
115+
# Wav2Vec2 uses uppercase characters; normalize to match its vocabulary
94116
text = text.upper().strip()
95117
tokens = []
96118
for i, word in enumerate(text.split()):
119+
# "|" is the word-boundary token in Wav2Vec2's CTC vocabulary
97120
if i > 0:
98121
tokens.append(self.vocab["|"])
99122
for char in word:
123+
# Skip characters not in vocab (punctuation, accents, etc.)
100124
if char in self.vocab:
101125
tokens.append(self.vocab[char])
102126
return tokens
103127

104-
def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
128+
def _forced_align(self, audio: np.ndarray, text: str) -> dict:
105129
"""Run forced alignment and return speech boundary info."""
130+
# Tokenize audio into Wav2Vec2 input features
106131
input_values = self.processor(audio, return_tensors="pt", sampling_rate=self.sr).input_values
107132

133+
# Forward pass through the CTC model to get per-frame logits
108134
with torch.no_grad():
109135
logits = self.model(input_values).logits[0]
110136

111137
log_probs = torch.log_softmax(logits, dim=-1)
112-
n_frames = len(logits)
113-
frame_duration = len(audio) / n_frames / self.sr
138+
frame_duration = self.frame_duration
114139

140+
# Convert target text to CTC token IDs and run torchaudio forced alignment
115141
target_tokens = self._text_to_tokens(text)
116142
fa_ids, fa_scores = taf.forced_align(
117143
log_probs.unsqueeze(0),
118144
torch.tensor([target_tokens]),
119145
blank=self.blank_id,
120146
)
147+
# fa_ids: per-frame aligned token IDs (blank or target token)
148+
# fa_scores: per-frame log-prob scores; convert to probabilities for confidence
121149
aligned_ids = fa_ids[0].numpy()
122150
scores = torch.exp(fa_scores[0]).numpy()
123151

152+
# Walk through the frame-level alignment and merge consecutive
153+
# frames of the same token into TokenSegment objects.
154+
# Transitions: blank-to-token (start new segment), token-to-blank (end segment),
155+
# token-to-different token (end old, start new).
124156
segments: list[TokenSegment] = []
125-
cur_id = -1
157+
cur_id = -1 # indicates that no segment is open
126158
seg_start = 0
127159
for i, aid in enumerate(aligned_ids):
128160
tid = int(aid)
129161
if tid == self.blank_id:
162+
# Blank frame: close the current segment if one is open
130163
if cur_id != -1:
131164
seg_scores = scores[seg_start:i]
132165
segments.append(
@@ -140,6 +173,7 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
140173
)
141174
cur_id = -1
142175
elif tid != cur_id:
176+
# New non-blank token: close previous segment (if any) and start a new one
143177
if cur_id != -1:
144178
seg_scores = scores[seg_start:i]
145179
segments.append(
@@ -153,7 +187,8 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
153187
)
154188
cur_id = tid
155189
seg_start = i
156-
# else: same non-blank token continues
190+
# else: same non-blank token continues — keep extending the segment
191+
# Flush the last open segment if the alignment ends on a non-blank token
157192
if cur_id != -1:
158193
seg_scores = scores[seg_start : len(aligned_ids)]
159194
segments.append(
@@ -166,6 +201,7 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
166201
)
167202
)
168203

204+
# No tokens were aligned — return zeroed-out defaults
169205
if not segments:
170206
return {
171207
"speech_end": 0.0,
@@ -188,17 +224,22 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
188224
last_speech = seg
189225
break
190226

227+
# Measure the blank gap between the last speech token and its predecessor.
228+
# A large gap can indicate noise or misalignment before the final sound.
191229
last_idx = segments.index(last_speech)
192230
if last_idx > 0:
193231
last_token_gap = last_speech.start - segments[last_idx - 1].end
194232
else:
233+
# First (and only) token — gap is measured from audio start
195234
last_token_gap = last_speech.start
196235

236+
# Average confidence of the last two alphanumeric tokens;
237+
# used as a fallback when the single last-token confidence is near zero.
197238
last_two_alnum = [s for s in segments if s.token.isalnum()][-2:]
198239
last_two_avg = float(np.mean([s.confidence for s in last_two_alnum]))
199240

200241
return {
201-
"speech_end": last.end, # + 0.05, # add 50ms of tolerance
242+
"speech_end": last.end,
202243
"last_token_duration": last_speech.duration,
203244
"last_token_confidence": last_speech.confidence,
204245
"last_token": last_speech.token,
@@ -213,7 +254,7 @@ def classify(
213254
text: str,
214255
) -> EoUClassification:
215256
"""
216-
Classify the end-of-utterance quality of a TTS audio sample.
257+
Classify the end-of-utterance quality of utterance audio.
217258
218259
Args:
219260
audio: Path to a WAV file, or a numpy array of audio samples at self.sr.
@@ -222,19 +263,24 @@ def classify(
222263
Returns:
223264
EoUClassification with the predicted eou_type and supporting features.
224265
"""
266+
# Accept either a file path or a pre-loaded numpy array
225267
if isinstance(audio, np.ndarray):
226268
samples = audio
227269
else:
228270
samples, _ = librosa.load(audio, sr=self.sr)
229271

230272
audio_dur = len(samples) / self.sr
231-
info = self._find_speech_end(samples, text)
273+
# Run forced alignment and collect information about speech segments
274+
info = self._forced_align(samples, text)
232275

233276
speech_end = info["speech_end"]
234277
trailing = audio_dur - speech_end
235278
last_letter_pad = 0.15 if _ends_with_sibilant(text) else 0.1
236279
trail_start = int((speech_end + last_letter_pad) * self.sr)
237280
trailing_audio = samples[trail_start:]
281+
282+
# Compute RMS energy ratio between the trailing region and the full
283+
# utterance — a high ratio means the tail is loud
238284
if len(trailing_audio) > 0:
239285
rms_trail = np.sqrt(np.mean(trailing_audio**2))
240286
rms_full = np.sqrt(np.mean(samples**2))
@@ -251,22 +297,19 @@ def classify(
251297
last_two_avg = info["last_two_phoneme_avg_confidence"]
252298
token_segments = info["token_segments"]
253299

254-
# if trailing < 0.06 and (last_dur < 0.025 and last_conf < 0.15):
255-
# if trailing < 0.06 and (last_gap < 0.1 and last_conf < 0.15): short trail and not due to gap
300+
# --- Decision tree for EoU classification ---
256301
conf_threshold = 0.07
257-
# short tail with low confidence and not due to gap (which could indicate noise) --> cutoff
302+
# Short tail with low confidence and not due to gap (which could indicate noise) --> cutoff
258303
if trailing < 0.1 and last_conf < conf_threshold and not last_gap > 0.4:
259-
# speech ends abruptly, with a very short last token or low confidence --> cutoff
260304
eou_type = EoUType.CUTOFF
261-
# long noisy tail OR odd gap --> noisy
305+
# Long noisy tail OR a gap between last two segments and low confidence --> noisy
262306
elif (trailing > 0.15 and trail_rms_ratio > 0.4) or (last_gap > 0.4 and last_conf < 0.15):
263-
# significant trailing region with high energy relative to speech --> noise
264307
eou_type = EoUType.NOISE
265-
# very long trailing region with near-zero energy --> silence
266-
elif trailing > 1.0: # and trail_rms_ratio < 0.10:
308+
# Long tail without much energy (or it would have been captured by the previous condition) --> silence
309+
elif trailing > 1.0:
267310
eou_type = EoUType.SILENCE
268311
else:
269-
# everything else (moderate trailing, natural energy decay) --> good
312+
# everything else --> good
270313
eou_type = EoUType.GOOD
271314

272315
return EoUClassification(

tests/collections/tts/metrics/test_eou_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)