Skip to content

Commit f48ebf5

Browse files
authored
Add job to plot Viterbi alignment (#558)
The job plots alignments from a set of alignment caches in an upward trend, as expected from a usual alignment plot.
1 parent 645e1fe commit f48ebf5

File tree

1 file changed

+185
-5
lines changed

1 file changed

+185
-5
lines changed

mm/alignment.py

Lines changed: 185 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Names exported by this module when it is star-imported.
__all__ = [
    "get_segment_name_to_alignment_mapping",
    "AlignmentJob",
    "DumpAlignmentJob",
    "PlotAlignmentJob",
    "AMScoresFromAlignmentLogJob",
    "ComputeTimeStampErrorJob",
    "GetLongestAllophoneFileJob",
    "PlotViterbiAlignmentJob",
]
911

1012
import itertools
@@ -13,20 +15,40 @@
1315
import os
1416
import shutil
1517
import statistics
18+
from typing import Callable, Counter, Dict, List, Optional, Tuple, Union
1619
import xml.etree.ElementTree as ET
17-
from typing import Callable, Counter, List, Optional, Tuple, Union
1820

21+
import numpy as np
1922
from sisyphus import *
2023

21-
Path = setup_path(__package__)
22-
24+
import i6_core.lib.corpus as corpus
2325
import i6_core.lib.rasr_cache as rasr_cache
2426
import i6_core.rasr as rasr
2527
import i6_core.util as util
2628

2729
from .flow import alignment_flow, dump_alignment_flow
2830

2931

32+
Path = setup_path(__package__)
33+
34+
35+
# Type alias: for each segment name, the list of per-frame alignment entries,
# each entry being a 4-tuple of (int, int, int, float).
_SegmentNameToAlignmentType = Dict[str, List[Tuple[int, int, int, float]]]
"""Mapping from segment names to `(timestamp, allophone_id, hmm_state, alignment_weight)`."""
37+
38+
39+
def get_segment_name_to_alignment_mapping(alignment_cache: rasr_cache.FileArchive) -> _SegmentNameToAlignmentType:
    """
    Read all alignments stored in an alignment cache into a dictionary.

    :param alignment_cache: Opened alignment cache from which to extract the alignments.
    :return: Mapping from segment names to alignments (by frame).
        The alignments are a list of tuples (timestamp, allophone_id, hmm_state, alignment_weight).
    """
    segment_name_to_alignment = {}
    for entry_name in alignment_cache.ft.keys():
        # Archive entries ending in ".attribs" are left out of the mapping
        # (presumably they hold metadata rather than alignments -- TODO confirm).
        if entry_name.endswith(".attribs"):
            continue
        segment_name_to_alignment[entry_name] = alignment_cache.read(entry_name, "align")
    return segment_name_to_alignment
50+
51+
3052
class AlignmentJob(rasr.RasrCommand, Job):
3153
"""
3254
Align a dataset with the given feature scorer.
@@ -153,7 +175,6 @@ def run(self, task_id):
153175
)
154176

155177
def plot(self):
156-
import numpy as np
157178
import matplotlib
158179
import matplotlib.pyplot as plt
159180

@@ -464,7 +485,6 @@ def tasks(self):
464485
yield Task("plot", resume="plot", rqmt=self.rqmt)
465486

466487
def plot(self):
467-
import numpy as np
468488
import matplotlib
469489
import matplotlib.pyplot as plt
470490

@@ -825,3 +845,163 @@ def run(self):
825845
line_set = {*lines} - {None}
826846
assert len(line_set) == 1, f"Line {i}: expected only one allophone, but found two or more: {line_set}."
827847
f.write(list(line_set)[0])
848+
849+
850+
class PlotViterbiAlignmentJob(Job):
    """
    Plots the alignments of each segment in the specified alignment files.

    The `run` task is parameterized with one (1-based) task id per alignment cache.
    Each plotted segment is stored as `<segment_name>.png` below the `plots` output directory.
    """

    def __init__(
        self,
        alignment_caches: List[tk.Path],
        allophone_file: tk.Path,
        segment_names_to_plot: Optional[tk.Path] = None,
        corpus_file: Optional[tk.Path] = None,
    ):
        """
        :param alignment_caches: Alignment files to be plotted.
        :param allophone_file: Allophone file used in the alignment process.
        :param segment_names_to_plot: File listing the specific segment names to plot, one name per line.
            By default, plot all segments given in :param:`alignment_caches`.
        :param corpus_file: Corpus used to generate the alignments. By default, the plots have no title.
            If provided, the plots will have the text from the respective segment as title,
            whenever the segment is available in the corpus. This should only be given for convenience.
        """
        self.alignment_caches = alignment_caches
        self.allophone_file = allophone_file
        self.segment_names_to_plot = segment_names_to_plot
        self.corpus_file = corpus_file

        self.out_plot_dir = self.output_path("plots", directory=True)

        self.rqmt = {"cpu": 1, "mem": 2.0, "time": 1.0}

    def tasks(self):
        # One task id per alignment cache; `run` maps the 1-based id back to a list index.
        yield Task("run", resume="run", rqmt=self.rqmt, args=range(1, len(self.alignment_caches) + 1))

    def extract_phoneme_sequence(self, alignment: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Collapse a frame-wise alignment into its sequence of labels.

        :param alignment: Monophone alignment, for instance: `np.array(["a", "a", "b", "b", "b", "c", ...])`.
        :return:
            - Monophone sequence (ordered as provided in :param:`alignment`).
              Consecutive identical labels are merged into a single entry.

            - **Indices** corresponding to the monophone sequence from the Viterbi alignment.
              In the example above, these would be `[0, 0, 1, 1, 1, 2, ...]`.

            Both arrays are empty if :param:`alignment` is empty.
        """
        alignment = np.asarray(alignment)
        if alignment.size == 0:
            # Nothing to collapse; indexing with the "last frame" boundary below would fail.
            return alignment, np.arange(0)

        # Index of the last frame of every run of identical labels.
        boundaries = np.concatenate(
            [
                np.where(alignment[:-1] != alignment[1:])[0],
                [len(alignment) - 1],  # manually add boundary of last allophone
            ]
        )

        # Run lengths: distance between consecutive boundaries (the first run starts after index -1).
        lengths = np.diff(boundaries, prepend=-1)
        phonemes = alignment[boundaries]
        monotonic_idx_alignment = np.repeat(np.arange(len(phonemes)), lengths)
        return phonemes, monotonic_idx_alignment

    def make_viterbi_matrix(self, label_indices: np.ndarray) -> np.ndarray:
        """
        :param label_indices: Sequence of label (allophone) indices,
            corresponding to the monophone sequence from the Viterbi alignment.

            For example, for an alignment of `np.array(["a", "a", "b", "b", "b", "c", ...])`,
            :param:`label_indices` would be `[0, 0, 1, 1, 1, 2, ...]`.
        :return: Boolean matrix of shape `(num_labels, num_frames)` corresponding to the Viterbi alignment.
            Entry `[l, t]` is `True` iff frame `t` is aligned to label index `l`.
            The matrix has shape `(0, 0)` if :param:`label_indices` is empty.
        """
        label_indices = np.asarray(label_indices, dtype=int)
        num_timestamps = len(label_indices)
        num_allophones = int(label_indices.max()) + 1 if num_timestamps > 0 else 0
        # Labels go on the rows (Y axis of the plot) so that they can be mapped to the phonemes there;
        # frames go on the columns (X axis).
        viterbi_matrix = np.zeros((num_allophones, num_timestamps), dtype=bool)
        # Mark exactly one label per frame in a single vectorized assignment.
        viterbi_matrix[label_indices, np.arange(num_timestamps)] = True
        return viterbi_matrix

    def plot(self, viterbi_matrix: np.ndarray, allophone_sequence: List[str], file_name: str, title: str = ""):
        """
        Render the alignment matrix and store it as a PNG image. Nothing is returned.

        :param viterbi_matrix: Matrix to be plotted, corresponding to the Viterbi alignment.
        :param allophone_sequence: Allophone sequence (Y-axis tick labels).
        :param file_name: File name where to store the plot, relative to `<job>/output/plots/`.
            The `.png` extension is appended automatically.
        :param title: Optional title to add to the image. By default there will be no title.
        """
        import matplotlib
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")

        num_allophones, num_timestamps = np.shape(viterbi_matrix)

        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.set_xlabel("Frame")
            ax.xaxis.set_label_coords(0.98, -0.03)
            ax.set_xbound(0, num_timestamps - 1)
            ax.set_ybound(-0.5, num_allophones - 0.5)

            ax.set_yticks(np.arange(num_allophones))
            ax.set_yticklabels(allophone_sequence)

            ax.set_title(title)

            ax.imshow(viterbi_matrix, cmap="Blues", interpolation="none", aspect="auto", origin="lower")

            # The plot will be purposefully divided into subdirectories
            # (any directory component of `file_name` is created below the output directory).
            plot_path = os.path.join(self.out_plot_dir.get_path(), f"{file_name}.png")
            os.makedirs(os.path.dirname(plot_path), exist_ok=True)
            fig.savefig(plot_path)
        finally:
            # Always release the figure, even if rendering or saving failed.
            plt.close(fig)

    def run(self, task_id: int):
        """
        Plot the segments of a single alignment cache.

        :param task_id: 1-based index into `alignment_caches` of the cache to process.
        """
        # Load the alignments from the alignment cache assigned to this task.
        align_cache = rasr_cache.FileArchive(self.alignment_caches[task_id - 1].get_path())
        align_cache.setAllophones(self.allophone_file.get_path())
        seg_name_to_alignments = get_segment_name_to_alignment_mapping(align_cache)

        if self.segment_names_to_plot is not None:
            # Only keep the specific segment names that the user has provided.
            with util.uopen(self.segment_names_to_plot.get_path(), "rt") as f:
                requested_seg_names = {seg_name.strip() for seg_name in f}
            seg_name_to_alignments = {
                seg_name: alignments
                for seg_name, alignments in seg_name_to_alignments.items()
                if seg_name in requested_seg_names
            }

        seg_name_to_text = {}
        if self.corpus_file is not None:
            c = corpus.Corpus()
            c.load(self.corpus_file.get_path())
            seg_name_to_text = {seg_name: segment.full_orth() for seg_name, segment in c.get_segment_mapping().items()}

        empty_alignment_seg_names = []
        for seg_name, alignments in seg_name_to_alignments.items():
            if alignments is None:
                continue
            # In some rare cases, the alignment doesn't have to reach a satisfactory end.
            # In these cases, the final alignment is empty. Skip those cases.
            if len(alignments) == 0:
                empty_alignment_seg_names.append(seg_name)
                continue

            # Get the central part of every allophone (everything before the first "{").
            # A new array is built instead of overwriting the loaded alignment entries.
            # NOTE(review): neighbouring allophones with the same central part end up in a single row of the plot,
            # because `extract_phoneme_sequence` only compares these central parts.
            center_allophones = np.array(
                [align_cache.allophones[allo_id].split("{")[0] for _, allo_id, _, _ in alignments]
            )
            phonemes, alignment_indices = self.extract_phoneme_sequence(center_allophones)
            viterbi_matrix = self.make_viterbi_matrix(alignment_indices)
            self.plot(viterbi_matrix, phonemes, file_name=seg_name, title=seg_name_to_text.get(seg_name, ""))

        if empty_alignment_seg_names:
            logging.warning(
                "The following alignments weren't plotted because their alignments were empty:\n"
                f"{empty_alignment_seg_names}"
            )

0 commit comments

Comments
 (0)