Skip to content

Commit 046cc5c

Browse files
committed
EOU algorithm improvements
Signed-off-by: Fejgin, Roy <rfejgin@nvidia.com>
1 parent 0652fbe commit 046cc5c

File tree

1 file changed

+121
-28
lines changed

1 file changed

+121
-28
lines changed

nemo/collections/tts/metrics/eou_classifier.py

Lines changed: 121 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
print(result.eou_type, result.trailing_duration)
1616
"""
1717

18-
from dataclasses import dataclass
18+
from dataclasses import dataclass, field
1919
from enum import StrEnum
2020
from typing import Union
2121

@@ -35,6 +35,15 @@ class EoUType(StrEnum):
3535
NOISE = "noise" # significant trailing region with high energy relative to speech
3636

3737

38+
@dataclass
class TokenSegment:
    """One contiguous run of a single non-blank CTC token in the forced alignment.

    Times are derived from frame indices multiplied by the frame duration,
    so ``duration == end - start`` up to floating-point rounding.
    """

    token: str  # decoded vocabulary token (e.g. a character/phoneme symbol)
    start: float  # segment start time in seconds
    end: float  # segment end time in seconds
    duration: float  # segment length in seconds (end - start)
    confidence: float  # mean per-frame posterior probability over the segment
45+
46+
3847
@dataclass
3948
class EoUClassification:
4049
eou_type: EoUType
@@ -44,6 +53,10 @@ class EoUClassification:
4453
trail_rms_ratio: float
4554
last_token_duration: float
4655
last_token_confidence: float
56+
last_token: str
57+
last_token_gap: float # blank gap (seconds) between last and second-to-last speech token
58+
last_two_phoneme_avg_confidence: float # average confidence of last two alphanumeric tokens
59+
token_segments: list[TokenSegment] = field(default_factory=list)
4760

4861

4962
class EoUClassifier:
@@ -61,6 +74,7 @@ def __init__(self, model_name: str = "facebook/wav2vec2-base-960h", sr: int = SR
6174
self.model.eval()
6275
self.blank_id = self.processor.tokenizer.pad_token_id
6376
self.vocab = self.processor.tokenizer.get_vocab()
77+
self.id_to_token = {v: k for k, v in self.vocab.items()}
6478

6579
def _text_to_tokens(self, text: str) -> list[int]:
6680
text = text.upper().strip()
@@ -93,32 +107,90 @@ def _find_speech_end(self, audio: np.ndarray, text: str) -> dict:
93107
aligned_ids = fa_ids[0].numpy()
94108
scores = torch.exp(fa_scores[0]).numpy()
95109

96-
speech_end_frame = 0
97-
last_token_start = 0
98-
last_token_id = -1
99-
for i in range(len(aligned_ids) - 1, -1, -1):
100-
if aligned_ids[i] != self.blank_id:
101-
if last_token_id == -1:
102-
speech_end_frame = i + 1
103-
last_token_id = int(aligned_ids[i])
104-
last_token_start = i
105-
elif int(aligned_ids[i]) != last_token_id:
106-
break
107-
else:
108-
last_token_start = i
109-
110-
if last_token_id == -1:
110+
segments: list[TokenSegment] = []
111+
cur_id = -1
112+
seg_start = 0
113+
for i, aid in enumerate(aligned_ids):
114+
tid = int(aid)
115+
if tid == self.blank_id:
116+
if cur_id != -1:
117+
seg_scores = scores[seg_start:i]
118+
segments.append(
119+
TokenSegment(
120+
token=self.id_to_token.get(cur_id, f"<id:{cur_id}>"),
121+
start=seg_start * frame_duration,
122+
end=i * frame_duration,
123+
duration=(i - seg_start) * frame_duration,
124+
confidence=float(seg_scores.mean()),
125+
)
126+
)
127+
cur_id = -1
128+
elif tid != cur_id:
129+
if cur_id != -1:
130+
seg_scores = scores[seg_start:i]
131+
segments.append(
132+
TokenSegment(
133+
token=self.id_to_token.get(cur_id, f"<id:{cur_id}>"),
134+
start=seg_start * frame_duration,
135+
end=i * frame_duration,
136+
duration=(i - seg_start) * frame_duration,
137+
confidence=float(seg_scores.mean()),
138+
)
139+
)
140+
cur_id = tid
141+
seg_start = i
142+
# else: same non-blank token continues
143+
if cur_id != -1:
144+
seg_scores = scores[seg_start : len(aligned_ids)]
145+
segments.append(
146+
TokenSegment(
147+
token=self.id_to_token.get(cur_id, f"<id:{cur_id}>"),
148+
start=seg_start * frame_duration,
149+
end=len(aligned_ids) * frame_duration,
150+
duration=(len(aligned_ids) - seg_start) * frame_duration,
151+
confidence=float(seg_scores.mean()),
152+
)
153+
)
154+
155+
if not segments:
111156
return {
112157
"speech_end": 0.0,
113158
"last_token_duration": 0.0,
114159
"last_token_confidence": 0.0,
160+
"last_token": "",
161+
"last_token_gap": 0.0,
162+
"last_two_phoneme_avg_confidence": 0.0,
163+
"token_segments": [],
115164
}
116165

117-
last_seg_scores = scores[last_token_start:speech_end_frame]
166+
last = segments[-1]
167+
168+
# Skip trailing punctuation/non-letter tokens for cutoff analysis,
169+
# since they don't correspond to real speech sounds and get
170+
# unreliably short durations from forced alignment.
171+
last_speech = last
172+
for seg in reversed(segments):
173+
if seg.token.isalnum():
174+
last_speech = seg
175+
break
176+
177+
last_idx = segments.index(last_speech)
178+
if last_idx > 0:
179+
last_token_gap = last_speech.start - segments[last_idx - 1].end
180+
else:
181+
last_token_gap = last_speech.start
182+
183+
last_two_alnum = [s for s in segments if s.token.isalnum()][-2:]
184+
last_two_avg = float(np.mean([s.confidence for s in last_two_alnum]))
185+
118186
return {
119-
"speech_end": speech_end_frame * frame_duration,
120-
"last_token_duration": (speech_end_frame - last_token_start) * frame_duration,
121-
"last_token_confidence": float(last_seg_scores.mean()),
187+
"speech_end": last.end, # + 0.05, # add 50ms of tolerance
188+
"last_token_duration": last_speech.duration,
189+
"last_token_confidence": last_speech.confidence,
190+
"last_token": last_speech.token,
191+
"last_token_gap": last_token_gap,
192+
"last_two_phoneme_avg_confidence": last_two_avg,
193+
"token_segments": segments,
122194
}
123195

124196
def classify(
@@ -146,8 +218,8 @@ def classify(
146218

147219
speech_end = info["speech_end"]
148220
trailing = audio_dur - speech_end
149-
150-
trail_start = int(speech_end * self.sr)
221+
last_letter_pad = 0.15
222+
trail_start = int((speech_end + last_letter_pad) * self.sr)
151223
trailing_audio = samples[trail_start:]
152224
if len(trailing_audio) > 0:
153225
rms_trail = np.sqrt(np.mean(trailing_audio**2))
@@ -158,15 +230,26 @@ def classify(
158230

159231
last_dur = info["last_token_duration"]
160232
last_conf = info["last_token_confidence"]
161-
162-
if trailing < 0.06 and last_dur < 0.025 and last_conf < 0.1:
163-
# speech ends abruptly, with a very short last token and low confidence --> cutoff
233+
if last_conf < 0.01:
234+
last_conf = info["last_two_phoneme_avg_confidence"]
235+
last_tok = info["last_token"]
236+
last_gap = info["last_token_gap"]
237+
last_two_avg = info["last_two_phoneme_avg_confidence"]
238+
token_segments = info["token_segments"]
239+
240+
# if trailing < 0.06 and (last_dur < 0.025 and last_conf < 0.15):
241+
# if trailing < 0.06 and (last_gap < 0.1 and last_conf < 0.15): short trail and not due to gap
242+
conf_threshold = 0.07
243+
# short tail with low confidence and not due to gap (which could indicate noise) --> cutoff
244+
if trailing < 0.1 and last_conf < conf_threshold and not last_gap > 0.4:
245+
# speech ends abruptly, with a very short last token or low confidence --> cutoff
164246
eou_type = EoUType.CUTOFF
165-
elif trailing > 0.3 and trail_rms_ratio > 0.5:
247+
# long noisy tail OR odd gap --> noisy
248+
elif (trailing > 0.25 and trail_rms_ratio > 0.3) or (last_gap > 0.4 and last_conf < conf_threshold):
166249
# significant trailing region with high energy relative to speech --> noise
167250
eou_type = EoUType.NOISE
168-
elif trailing > 1.0 and trail_rms_ratio < 0.10:
169-
# very long trailing region with near-zero energy --> silence
251+
# very long trailing region with near-zero energy --> silence
252+
elif trailing > 1.0: # and trail_rms_ratio < 0.10:
170253
eou_type = EoUType.SILENCE
171254
else:
172255
# everything else (moderate trailing, natural energy decay) --> good
@@ -180,6 +263,10 @@ def classify(
180263
trail_rms_ratio=trail_rms_ratio,
181264
last_token_duration=last_dur,
182265
last_token_confidence=last_conf,
266+
last_token=last_tok,
267+
last_token_gap=last_gap,
268+
last_two_phoneme_avg_confidence=last_two_avg,
269+
token_segments=token_segments,
183270
)
184271

185272

@@ -200,3 +287,9 @@ def classify(
200287
print(f"trail_rms_ratio: {result.trail_rms_ratio:.4f}")
201288
print(f"last_token_dur: {result.last_token_duration:.3f}s")
202289
print(f"last_token_conf: {result.last_token_confidence:.3f}")
290+
print(f"last_token_gap: {result.last_token_gap:.3f}s")
291+
print(f"last_2_ph_avg_conf: {result.last_two_phoneme_avg_confidence:.3f}")
292+
print(f"last_token: {result.last_token!r}")
293+
print(f"\nToken segments ({len(result.token_segments)}):")
294+
for seg in result.token_segments:
295+
print(f" {seg.token!r:<6} {seg.start:.3f}-{seg.end:.3f}s dur={seg.duration:.3f}s conf={seg.confidence:.3f}")

0 commit comments

Comments
 (0)