Skip to content

Commit 1129a59

Browse files
authored
ComputeTimeStampErrorJob: allow multiple caches as input, add output file with individual TSEs per segment pair (#605)
* `ComputeTimeStampErrorJob`: also accept a list for ref/hyp caches
* Modify code to adapt to the list format
* Add output of sorted highest TSE differences for analyzed segments
* Update imports
* Uniformize output (sad but needed)
* Remove "highest" prefix from output file
* Rename self parameter to plural — technically it's a list of alignment caches
1 parent 69f1e08 commit 1129a59

File tree

1 file changed

+104
-91
lines changed

1 file changed

+104
-91
lines changed

mm/alignment.py

Lines changed: 104 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,8 @@ class ComputeTimeStampErrorJob(Job):
562562

563563
def __init__(
564564
self,
565-
hyp_alignment_cache: tk.Path,
566-
ref_alignment_cache: tk.Path,
565+
hyp_alignment_cache: Union[tk.Path, List[tk.Path]],
566+
ref_alignment_cache: Union[tk.Path, List[tk.Path]],
567567
hyp_allophone_file: tk.Path,
568568
ref_allophone_file: tk.Path,
569569
hyp_silence_phone: str = "[SILENCE]",
@@ -585,12 +585,16 @@ def __init__(
585585
:param hyp_seq_tag_transform: Function that transforms seq tag in alignment cache such that it matches the seq tags in the reference
586586
:param remove_outlier_limit: If set, boundary differences greater than this frame limit are discarded from computation
587587
"""
588-
self.hyp_alignment_cache = hyp_alignment_cache
588+
self.hyp_alignment_caches = (
589+
hyp_alignment_cache if isinstance(hyp_alignment_cache, List) else [hyp_alignment_cache]
590+
)
589591
self.hyp_allophone_file = hyp_allophone_file
590592
self.hyp_silence_phone = hyp_silence_phone
591593
self.hyp_upsample_factor = hyp_upsample_factor
592594

593-
self.ref_alignment_cache = ref_alignment_cache
595+
self.ref_alignment_caches = (
596+
ref_alignment_cache if isinstance(ref_alignment_cache, List) else [ref_alignment_cache]
597+
)
594598
self.ref_allophone_file = ref_allophone_file
595599
self.ref_silence_phone = ref_silence_phone
596600
self.ref_upsample_factor = ref_upsample_factor
@@ -605,6 +609,7 @@ def __init__(
605609
self.out_plot_word_end_frame_differences = self.output_path("end_frame_differences.png")
606610
self.out_boundary_frame_differences = self.output_var("boundary_frame_differences")
607611
self.out_plot_boundary_frame_differences = self.output_path("boundary_frame_differences.png")
612+
self.out_tse_differences_file = self.output_path("tse_differences.txt")
608613

609614
self.rqmt = None
610615

@@ -660,102 +665,105 @@ def run(self) -> None:
660665
start_differences = Counter()
661666
end_differences = Counter()
662667
differences = Counter()
668+
tse_dict: Dict[str, Tuple[str, float]] = {} # ref_seg_name: (hyp_seg_name, avg_tse)
663669

664-
hyp_alignments = rasr_cache.open_file_archive(self.hyp_alignment_cache.get())
665-
hyp_alignments.setAllophones(self.hyp_allophone_file.get())
666-
if isinstance(hyp_alignments, rasr_cache.FileArchiveBundle):
667-
hyp_allophone_map = next(iter(hyp_alignments.archives.values())).allophones
668-
else:
669-
hyp_allophone_map = hyp_alignments.allophones
670-
671-
ref_alignments = rasr_cache.open_file_archive(self.ref_alignment_cache.get())
672-
ref_alignments.setAllophones(self.ref_allophone_file.get())
673-
if isinstance(ref_alignments, rasr_cache.FileArchiveBundle):
674-
ref_allophone_map = next(iter(ref_alignments.archives.values())).allophones
675-
else:
676-
ref_allophone_map = ref_alignments.allophones
677-
678-
file_list = [tag for tag in hyp_alignments.file_list() if not tag.endswith(".attribs")]
679-
680-
for idx, hyp_seq_tag in enumerate(file_list, start=1):
681-
hyp_word_starts, hyp_word_ends, hyp_seq_length = self._compute_word_boundaries(
682-
hyp_alignments,
683-
hyp_allophone_map,
684-
hyp_seq_tag,
685-
self.hyp_silence_phone,
686-
self.hyp_upsample_factor,
687-
)
688-
assert len(hyp_word_starts) == len(hyp_word_ends), (
689-
f"Found different number of word starts ({len(hyp_word_starts)}) "
690-
f"than word ends ({len(hyp_word_ends)}). Something seems to be broken."
691-
)
670+
for hyp_alignment_cache, ref_alignment_cache in zip(self.hyp_alignment_caches, self.ref_alignment_caches):
671+
hyp_alignments = rasr_cache.open_file_archive(hyp_alignment_cache.get())
672+
hyp_alignments.setAllophones(self.hyp_allophone_file.get())
673+
if isinstance(hyp_alignments, rasr_cache.FileArchiveBundle):
674+
hyp_allophone_map = next(iter(hyp_alignments.archives.values())).allophones
675+
else:
676+
hyp_allophone_map = hyp_alignments.allophones
692677

693-
if self.hyp_seq_tag_transform is not None:
694-
ref_seq_tag = self.hyp_seq_tag_transform(hyp_seq_tag)
678+
ref_alignments = rasr_cache.open_file_archive(ref_alignment_cache.get())
679+
ref_alignments.setAllophones(self.ref_allophone_file.get())
680+
if isinstance(ref_alignments, rasr_cache.FileArchiveBundle):
681+
ref_allophone_map = next(iter(ref_alignments.archives.values())).allophones
695682
else:
696-
ref_seq_tag = hyp_seq_tag
697-
698-
ref_word_starts, ref_word_ends, ref_seq_length = self._compute_word_boundaries(
699-
ref_alignments,
700-
ref_allophone_map,
701-
ref_seq_tag,
702-
self.ref_silence_phone,
703-
self.ref_upsample_factor,
704-
)
705-
assert len(ref_word_starts) == len(ref_word_ends), (
706-
f"Found different number of word starts ({len(hyp_word_starts)}) "
707-
f"than word ends ({len(hyp_word_ends)}) in reference. Something seems to be broken."
708-
)
683+
ref_allophone_map = ref_alignments.allophones
709684

710-
if len(hyp_word_starts) != len(ref_word_starts):
711-
logging.warning(
712-
f"Sequence {hyp_seq_tag} ({idx} / {len(file_list)}:\n Discarded because the number of words in alignment ({len(hyp_word_starts)}) does not equal the number of words in reference ({len(ref_word_starts)})."
713-
)
714-
discarded_seqs += 1
715-
continue
685+
file_list = [tag for tag in hyp_alignments.file_list() if not tag.endswith(".attribs")]
716686

717-
# Sometimes different feature extraction or subsampling may produce mismatched lengths that are different by a few frames, so cut off at the shorter length
718-
shorter_seq_length = min(hyp_seq_length, ref_seq_length)
687+
for idx, hyp_seq_tag in enumerate(file_list, start=1):
688+
hyp_word_starts, hyp_word_ends, hyp_seq_length = self._compute_word_boundaries(
689+
hyp_alignments,
690+
hyp_allophone_map,
691+
hyp_seq_tag,
692+
self.hyp_silence_phone,
693+
self.hyp_upsample_factor,
694+
)
695+
assert len(hyp_word_starts) == len(hyp_word_ends), (
696+
f"Found different number of word starts ({len(hyp_word_starts)}) "
697+
f"than word ends ({len(hyp_word_ends)}). Something seems to be broken."
698+
)
719699

720-
for i in range(len(hyp_word_ends) - 1, 0, -1):
721-
if hyp_word_ends[i] > shorter_seq_length:
722-
hyp_word_ends[i] = shorter_seq_length
723-
hyp_word_starts[i] = min(hyp_word_starts[i], hyp_word_ends[i] - 1)
700+
if self.hyp_seq_tag_transform is not None:
701+
ref_seq_tag = self.hyp_seq_tag_transform(hyp_seq_tag)
724702
else:
725-
break
726-
for i in range(len(ref_word_ends) - 1, 0, -1):
727-
if ref_word_ends[i] > shorter_seq_length:
728-
ref_word_ends[i] = shorter_seq_length
729-
ref_word_starts[i] = min(ref_word_starts[i], ref_word_ends[i] - 1)
730-
else:
731-
break
732-
733-
seq_word_start_diffs = [start - ref_start for start, ref_start in zip(hyp_word_starts, ref_word_starts)]
734-
seq_word_end_diffs = [end - ref_end for end, ref_end in zip(hyp_word_ends, ref_word_ends)]
735-
736-
# Optionally remove outliers
737-
seq_word_start_diffs = [diff for diff in seq_word_start_diffs if abs(diff) <= self.remove_outlier_limit]
738-
seq_word_end_diffs = [diff for diff in seq_word_end_diffs if abs(diff) <= self.remove_outlier_limit]
739-
740-
seq_differences = seq_word_start_diffs + seq_word_end_diffs
741-
742-
start_differences.update(seq_word_start_diffs)
743-
end_differences.update(seq_word_end_diffs)
744-
differences.update(seq_differences)
745-
746-
if seq_differences:
747-
seq_tse = statistics.mean(abs(diff) for diff in seq_differences)
748-
749-
logging.info(
750-
f"Sequence {hyp_seq_tag} ({idx} / {len(file_list)}):\n Word start distances are {seq_word_start_diffs}\n Word end distances are {seq_word_end_diffs}\n Sequence TSE is {seq_tse} frames"
703+
ref_seq_tag = hyp_seq_tag
704+
705+
ref_word_starts, ref_word_ends, ref_seq_length = self._compute_word_boundaries(
706+
ref_alignments,
707+
ref_allophone_map,
708+
ref_seq_tag,
709+
self.ref_silence_phone,
710+
self.ref_upsample_factor,
751711
)
752-
counted_seqs += 1
753-
else:
754-
logging.warning(
755-
f"Sequence {hyp_seq_tag} ({idx} / {len(file_list)}):\n Discarded since all distances are over the upper limit"
712+
assert len(ref_word_starts) == len(ref_word_ends), (
713+
f"Found different number of word starts ({len(hyp_word_starts)}) "
714+
f"than word ends ({len(hyp_word_ends)}) in reference. Something seems to be broken."
756715
)
757-
discarded_seqs += 1
758-
continue
716+
717+
if len(hyp_word_starts) != len(ref_word_starts):
718+
logging.warning(
719+
f"Sequence {hyp_seq_tag} ({idx} / {len(file_list)}:\n Discarded because the number of words in alignment ({len(hyp_word_starts)}) does not equal the number of words in reference ({len(ref_word_starts)})."
720+
)
721+
discarded_seqs += 1
722+
continue
723+
724+
# Sometimes different feature extraction or subsampling may produce mismatched lengths that are different by a few frames, so cut off at the shorter length
725+
shorter_seq_length = min(hyp_seq_length, ref_seq_length)
726+
727+
for i in range(len(hyp_word_ends) - 1, 0, -1):
728+
if hyp_word_ends[i] > shorter_seq_length:
729+
hyp_word_ends[i] = shorter_seq_length
730+
hyp_word_starts[i] = min(hyp_word_starts[i], hyp_word_ends[i] - 1)
731+
else:
732+
break
733+
for i in range(len(ref_word_ends) - 1, 0, -1):
734+
if ref_word_ends[i] > shorter_seq_length:
735+
ref_word_ends[i] = shorter_seq_length
736+
ref_word_starts[i] = min(ref_word_starts[i], ref_word_ends[i] - 1)
737+
else:
738+
break
739+
740+
seq_word_start_diffs = [start - ref_start for start, ref_start in zip(hyp_word_starts, ref_word_starts)]
741+
seq_word_end_diffs = [end - ref_end for end, ref_end in zip(hyp_word_ends, ref_word_ends)]
742+
743+
# Optionally remove outliers
744+
seq_word_start_diffs = [diff for diff in seq_word_start_diffs if abs(diff) <= self.remove_outlier_limit]
745+
seq_word_end_diffs = [diff for diff in seq_word_end_diffs if abs(diff) <= self.remove_outlier_limit]
746+
747+
seq_differences = seq_word_start_diffs + seq_word_end_diffs
748+
749+
start_differences.update(seq_word_start_diffs)
750+
end_differences.update(seq_word_end_diffs)
751+
differences.update(seq_differences)
752+
753+
if seq_differences:
754+
seq_tse = statistics.mean(abs(diff) for diff in seq_differences)
755+
tse_dict[ref_seq_tag] = (hyp_seq_tag, seq_tse)
756+
757+
logging.info(
758+
f"Sequence {hyp_seq_tag} ({idx} / {len(file_list)}):\n Word start distances are {seq_word_start_diffs}\n Word end distances are {seq_word_end_diffs}\n Sequence TSE is {seq_tse} frames"
759+
)
760+
counted_seqs += 1
761+
else:
762+
logging.warning(
763+
f"Sequence {hyp_seq_tag} ({idx} / {len(file_list)}):\n Discarded since all distances are over the upper limit"
764+
)
765+
discarded_seqs += 1
766+
continue
759767

760768
logging.info(
761769
f"Processing finished. Computed TSE value based on {counted_seqs} sequences; {discarded_seqs} sequences were discarded."
@@ -767,6 +775,11 @@ def run(self) -> None:
767775
self.out_word_end_frame_differences.set({key: end_differences[key] for key in sorted(end_differences.keys())})
768776
self.out_boundary_frame_differences.set({key: differences[key] for key in sorted(differences.keys())})
769777
self.out_tse_frames.set(statistics.mean(abs(diff) for diff in differences.elements()))
778+
with util.uopen(self.out_tse_differences_file.get_path(), "wt") as f:
779+
for ref_seq_tag, (hyp_seq_tag, avg_tse) in sorted(
780+
tse_dict.items(), key=lambda k_v: k_v[1][1], reverse=True
781+
):
782+
f.write(f"{ref_seq_tag}\t{hyp_seq_tag}\t{avg_tse}\n")
770783

771784
def plot(self):
772785
for descr, dict_file, plot_file in [

0 commit comments

Comments
 (0)