Skip to content

Commit 31454fc

Browse files
committed
Add unit test for EOU classifier
Signed-off-by: Fejgin, Roy <rfejgin@nvidia.com>
1 parent f891642 commit 31454fc

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import librosa
16+
import numpy as np
17+
import pytest
18+
19+
from nemo.collections.tts.metrics.eou_classifier import (
20+
EoUClassification,
21+
EoUClassifier,
22+
EoUType,
23+
TokenSegment,
24+
_ends_with_sibilant,
25+
)
26+
27+
# ---------------------------------------------------------------------------
28+
# TODO: Fill in (audio_path, text) pairs per EoU class.
29+
# Paths are relative to the repo root. Multiple examples per class are supported.
30+
# ---------------------------------------------------------------------------
31+
DATA_PATH = "/home/TestData/tts/eou_classifier_unit_test"
32+
_CLASSIFICATION_CASES: list[tuple[EoUType, str, str]] = [
33+
(EoUType.GOOD, f"{DATA_PATH}/rodney.wav", "Yes, it is quite amazing to watch and I love all of it."),
34+
(
35+
EoUType.CUTOFF,
36+
f"{DATA_PATH}/libritts_test_clean_1320_122612_000056_000003.wav",
37+
"Having reached within a few yards of the latter, he arose to his feet, silently and slowly.",
38+
),
39+
(EoUType.SILENCE, f"{DATA_PATH}/magpie_silence_wood.wav", "w o o d"),
40+
(EoUType.NOISE, f"{DATA_PATH}/magpie_noisy_yes.wav", "yes"),
41+
# this one starts looping the text at the end, should be detected as noise
42+
(
43+
EoUType.NOISE,
44+
f"{DATA_PATH}/magpie_repeated_tail.wav",
45+
"Put them away quick before Andella and Rosalie see them.",
46+
),
47+
]
48+
49+
50+
@pytest.fixture(scope="module")
51+
def classifier():
52+
"""Load the Wav2Vec2 model once for the entire test module."""
53+
return EoUClassifier()
54+
55+
56+
# ── classification tests (one per class) ──────────────────────────────────
57+
58+
59+
@pytest.mark.unit
60+
@pytest.mark.parametrize(
61+
"eou_type, audio_path, text", _CLASSIFICATION_CASES, ids=[p for _, p, _ in _CLASSIFICATION_CASES]
62+
)
63+
def test_classification_matches_expected_class(classifier, eou_type, audio_path, text):
64+
"""Each sample should be classified as its expected EoU type."""
65+
result = classifier.classify(audio_path, text)
66+
67+
assert isinstance(result, EoUClassification)
68+
assert result.eou_type == eou_type, (
69+
f"Expected {eou_type.value!r} but got {result.eou_type.value!r} "
70+
f"(trailing={result.trailing_duration:.3f}s, rms_ratio={result.trail_rms_ratio:.4f}, "
71+
f"last_conf={result.last_token_confidence:.3f})"
72+
)
73+
74+
75+
# ── numpy array input ─────────────────────────────────────────────────────
76+
77+
78+
@pytest.mark.unit
79+
def test_classify_accepts_numpy_array(classifier):
80+
"""Classifier should accept a pre-loaded numpy array instead of a path."""
81+
_, audio_path, text = next(c for c in _CLASSIFICATION_CASES if c[0] == EoUType.GOOD)
82+
samples, _ = librosa.load(audio_path, sr=16000)
83+
84+
result_from_path = classifier.classify(audio_path, text)
85+
result_from_array = classifier.classify(samples, text)
86+
87+
assert result_from_path.eou_type == result_from_array.eou_type
88+
assert abs(result_from_path.trailing_duration - result_from_array.trailing_duration) < 1e-4
89+
90+
91+
# ── return value structure ────────────────────────────────────────────────
92+
93+
94+
@pytest.mark.unit
95+
def test_classification_result_structure(classifier):
96+
"""Verify the returned dataclass fields have correct types and reasonable ranges."""
97+
_, audio_path, text = next(c for c in _CLASSIFICATION_CASES if c[0] == EoUType.GOOD)
98+
result = classifier.classify(audio_path, text)
99+
100+
assert isinstance(result.eou_type, EoUType)
101+
assert result.speech_end >= 0.0
102+
assert result.audio_duration > 0.0
103+
assert result.trailing_duration >= 0.0
104+
assert result.speech_end <= result.audio_duration + 0.5 # small tolerance for frame rounding
105+
assert 0.0 <= result.trail_rms_ratio
106+
assert result.last_token_duration >= 0.0
107+
assert 0.0 <= result.last_token_confidence <= 1.0
108+
assert isinstance(result.last_token, str)
109+
assert result.last_token_gap >= 0.0
110+
assert 0.0 <= result.last_two_phoneme_avg_confidence <= 1.0
111+
112+
assert isinstance(result.token_segments, list)
113+
assert len(result.token_segments) > 0
114+
for seg in result.token_segments:
115+
assert isinstance(seg, TokenSegment)
116+
assert seg.end >= seg.start
117+
assert seg.duration >= 0.0
118+
assert 0.0 <= seg.confidence <= 1.0

0 commit comments

Comments
 (0)