Skip to content

Commit 20f19f9

Browse files
committed
Add back filewise metric filtering
It has been moved outside of evaluate() since in the NeMo Skills use case we need the full metrics for chunk-wise scoring and aggregation at the end. Signed-off-by: Fejgin, Roy <rfejgin@nvidia.com>
1 parent 4589a8c commit 20f19f9

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

examples/tts/magpietts_inference.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,21 @@ def run_inference_and_evaluation(
308308
with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f:
309309
json.dump(metrics, f, indent=4)
310310

311-
# Sort by CER descending for human-readable output (highest error first)
312-
sorted_filewise = sorted(filewise_metrics, key=lambda x: x.get('cer', 0), reverse=True)
311+
filewise_metrics_keys_to_save = [
312+
'cer',
313+
'wer',
314+
'pred_context_ssim',
315+
'pred_text',
316+
'gt_text',
317+
'gt_audio_filepath',
318+
'pred_audio_filepath',
319+
'context_audio_filepath',
320+
'utmosv2',
321+
]
322+
filtered_filewise = [{k: m[k] for k in filewise_metrics_keys_to_save if k in m} for m in filewise_metrics]
323+
filtered_filewise.sort(key=lambda x: x.get('cer', 0), reverse=True)
313324
with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f:
314-
json.dump(sorted_filewise, f, indent=4)
325+
json.dump(filtered_filewise, f, indent=4)
315326

316327
# Append to per-run CSV
317328
append_metrics_to_csv(per_run_csv, full_checkpoint_name, dataset, metrics)

nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""
15-
Used in infer_and_evaluate.py to obtain metrics such as ASR_WER and UTMOSV2 scores.
15+
Used in inference and evaluation scripts to obtain metrics such as ASR_WER and UTMOSV2 scores.
1616
"""
1717
import argparse
1818
import json

0 commit comments

Comments
 (0)