|
# Names exported by this module when it is star-imported (`from <module> import *`).
__all__ = [
    "get_segment_name_to_alignment_mapping",
    "AlignmentJob",
    "DumpAlignmentJob",
    "PlotAlignmentJob",
    "AMScoresFromAlignmentLogJob",
    "ComputeTimeStampErrorJob",
    "GetLongestAllophoneFileJob",
    "PlotViterbiAlignmentJob",
]
9 | 11 |
|
10 | 12 | import itertools |
|
13 | 15 | import os |
14 | 16 | import shutil |
15 | 17 | import statistics |
| 18 | +from typing import Callable, Counter, Dict, List, Optional, Tuple, Union |
16 | 19 | import xml.etree.ElementTree as ET |
17 | | -from typing import Callable, Counter, List, Optional, Tuple, Union |
18 | 20 |
|
| 21 | +import numpy as np |
19 | 22 | from sisyphus import * |
20 | 23 |
|
21 | | -Path = setup_path(__package__) |
22 | | - |
| 24 | +import i6_core.lib.corpus as corpus |
23 | 25 | import i6_core.lib.rasr_cache as rasr_cache |
24 | 26 | import i6_core.rasr as rasr |
25 | 27 | import i6_core.util as util |
26 | 28 |
|
27 | 29 | from .flow import alignment_flow, dump_alignment_flow |
28 | 30 |
|
29 | 31 |
|
# Path helper bound to this package.
# NOTE(review): `setup_path` is presumably provided by the `from sisyphus import *` star import - confirm.
Path = setup_path(__package__)


# Type alias for the per-segment alignments: each segment name maps to a list with one 4-tuple per entry.
_SegmentNameToAlignmentType = Dict[str, List[Tuple[int, int, int, float]]]
"""Mapping from segment names to `(timestamp, allophone_id, hmm_state, alignment_weight)`."""
| 37 | + |
| 38 | + |
def get_segment_name_to_alignment_mapping(alignment_cache: rasr_cache.FileArchive) -> _SegmentNameToAlignmentType:
    """
    Read all alignments stored in an alignment cache.

    :param alignment_cache: Opened alignment cache from which to extract the alignments.
    :return: Mapping from segment names to alignments (by frame).
        The alignments are a list of tuples (timestamp, allophone_id, hmm_state, alignment_weight).
    """
    segment_name_to_alignment = {}
    for entry_name in alignment_cache.ft.keys():
        # Entries ending in ".attribs" are not read as alignments, so they are left out of the mapping.
        if entry_name.endswith(".attribs"):
            continue
        segment_name_to_alignment[entry_name] = alignment_cache.read(entry_name, "align")
    return segment_name_to_alignment
| 50 | + |
| 51 | + |
30 | 52 | class AlignmentJob(rasr.RasrCommand, Job): |
31 | 53 | """ |
32 | 54 | Align a dataset with the given feature scorer. |
@@ -153,7 +175,6 @@ def run(self, task_id): |
153 | 175 | ) |
154 | 176 |
|
155 | 177 | def plot(self): |
156 | | - import numpy as np |
157 | 178 | import matplotlib |
158 | 179 | import matplotlib.pyplot as plt |
159 | 180 |
|
@@ -464,7 +485,6 @@ def tasks(self): |
464 | 485 | yield Task("plot", resume="plot", rqmt=self.rqmt) |
465 | 486 |
|
466 | 487 | def plot(self): |
467 | | - import numpy as np |
468 | 488 | import matplotlib |
469 | 489 | import matplotlib.pyplot as plt |
470 | 490 |
|
@@ -825,3 +845,163 @@ def run(self): |
825 | 845 | line_set = {*lines} - {None} |
826 | 846 | assert len(line_set) == 1, f"Line {i}: expected only one allophone, but found two or more: {line_set}." |
827 | 847 | f.write(list(line_set)[0]) |
| 848 | + |
| 849 | + |
class PlotViterbiAlignmentJob(Job):
    """
    Plots the alignments of each segment in the specified alignment files.

    One `run` task is scheduled per alignment cache. For every plotted segment, a PNG image is stored
    below the `plots` output directory, at the path given by the segment name plus the `.png` suffix.
    """

    def __init__(
        self,
        alignment_caches: List[tk.Path],
        allophone_file: tk.Path,
        segment_names_to_plot: Optional[tk.Path] = None,
        corpus_file: Optional[tk.Path] = None,
    ):
        """
        :param alignment_caches: Alignment files to be plotted.
        :param allophone_file: Allophone file used in the alignment process.
        :param segment_names_to_plot: File with the specific segment names to plot, one per line.
            By default, plot all segments given in :param:`alignment_caches`.
        :param corpus_file: Corpus used to generate the alignments. By default, the plots have no title.
            If provided, the plots will have the text from the respective segment as title,
            whenever the segment is available in the corpus. This should only be given for convenience.
        """
        self.alignment_caches = alignment_caches
        self.allophone_file = allophone_file
        self.segment_names_to_plot = segment_names_to_plot
        self.corpus_file = corpus_file

        self.out_plot_dir = self.output_path("plots", directory=True)

        self.rqmt = {"cpu": 1, "mem": 2.0, "time": 1.0}

    def tasks(self):
        # One task per alignment cache; task ids are 1-based.
        yield Task("run", resume="run", rqmt=self.rqmt, args=range(1, len(self.alignment_caches) + 1))

    def extract_phoneme_sequence(self, alignment: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        :param alignment: Monophone alignment, for instance: `np.array(["a", "a", "b", "b", "b", "c", ...])`.
        :return:
            - Monophone sequence (ordered as provided in :param:`alignment`).

            - **Indices** corresponding to the monophone sequence from the Viterbi alignment.
              In the example above, these would be `[0, 0, 1, 1, 1, 2, ...]`.

            Both arrays are empty if :param:`alignment` is empty.
        """
        alignment = np.asarray(alignment)
        if len(alignment) == 0:
            return alignment[:0], np.arange(0)

        # Position of the last frame of each run of identical labels.
        boundaries = np.concatenate(
            [
                np.where(alignment[:-1] != alignment[1:])[0],
                [len(alignment) - 1],  # manually add boundary of last allophone
            ]
        )

        # Run lengths are the distances between consecutive boundaries (the first run starts after index -1).
        lengths = np.diff(boundaries, prepend=-1)
        phonemes = alignment[boundaries]
        monotonic_idx_alignment = np.repeat(np.arange(len(phonemes)), lengths)
        return phonemes, monotonic_idx_alignment

    def make_viterbi_matrix(self, label_indices: np.ndarray) -> np.ndarray:
        """
        :param label_indices: Sequence of label (allophone) indices,
            corresponding to the monophone sequence from the Viterbi alignment.

            For example, for an alignment of `np.array(["a", "a", "b", "b", "b", "c", ...])`,
            :param:`label_indices` would be `[0, 0, 1, 1, 1, 2, ...]`.
        :return: Boolean matrix of shape `(num_allophones, num_timestamps)` corresponding to the Viterbi alignment.
            The entry `[label, t]` is `True` iff frame `t` is aligned to `label`.
            An empty input yields an empty `(0, 0)` matrix.
        """
        label_indices = np.asarray(label_indices, dtype=int)
        num_timestamps = len(label_indices)
        if num_timestamps == 0:
            return np.zeros((0, 0), dtype=bool)

        num_allophones = int(label_indices.max()) + 1
        # Labels go on the Y axis (rows) because the row indices are later mapped to the different phonemes,
        # while the timestamps go on the X axis (columns).
        viterbi_matrix = np.zeros((num_allophones, num_timestamps), dtype=bool)
        viterbi_matrix[label_indices, np.arange(num_timestamps)] = True
        return viterbi_matrix

    def plot(self, viterbi_matrix: np.ndarray, allophone_sequence: List[str], file_name: str, title: str = ""):
        """
        Store the given Viterbi matrix as PNG image. Nothing is returned.

        :param viterbi_matrix: Matrix to be plotted, corresponding to the Viterbi alignment.
        :param allophone_sequence: Allophone sequence (Y-axis tick labels).
        :param file_name: File name where to store the plot, relative to `<job>/output/plots/`.
            The suffix `.png` is appended. Missing parent directories are created.
        :param title: Optional title to add to the image. By default there will be no title.
        """
        import matplotlib
        import matplotlib.pyplot as plt

        matplotlib.use("Agg")

        num_allophones, num_timestamps = np.shape(viterbi_matrix)

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xlabel("Frame")
        ax.xaxis.set_label_coords(0.98, -0.03)
        ax.set_xbound(0, num_timestamps - 1)
        ax.set_ybound(-0.5, num_allophones - 0.5)

        ax.set_yticks(np.arange(num_allophones))
        ax.set_yticklabels(allophone_sequence)

        ax.set_title(title)

        ax.imshow(viterbi_matrix, cmap="Blues", interpolation="none", aspect="auto", origin="lower")

        # The plot will be purposefully divided into subdirectories (`file_name` may contain path separators).
        plot_path = os.path.join(self.out_plot_dir.get_path(), f"{file_name}.png")
        os.makedirs(os.path.dirname(plot_path), exist_ok=True)
        fig.savefig(plot_path)
        plt.close(fig)

    def run(self, task_id: int):
        """
        Plot the segments of the alignment cache assigned to this task.

        :param task_id: 1-based index into the list of alignment caches.
        """
        align_cache = rasr_cache.FileArchive(self.alignment_caches[task_id - 1].get_path())
        align_cache.setAllophones(self.allophone_file.get_path())

        # Segment names requested by the user, or None to plot everything from the local alignment cache.
        requested_seg_names = None
        if self.segment_names_to_plot is not None:
            with util.uopen(self.segment_names_to_plot.get_path(), "rt") as f:
                requested_seg_names = {seg_name.strip() for seg_name in f}

        # Only select the names here. Each alignment is read right before it is plotted,
        # so that alignments which were not requested are never read from the cache.
        seg_names_to_plot = [
            seg_name
            for seg_name in align_cache.ft.keys()
            if not seg_name.endswith(".attribs")
            and (requested_seg_names is None or seg_name in requested_seg_names)
        ]

        seg_name_to_text = {}
        if self.corpus_file is not None:
            c = corpus.Corpus()
            c.load(self.corpus_file.get_path())
            seg_name_to_text = {seg_name: segment.full_orth() for seg_name, segment in c.get_segment_mapping().items()}

        empty_alignment_seg_names = []
        for seg_name in seg_names_to_plot:
            alignment = align_cache.read(seg_name, "align")
            if alignment is None:
                continue
            # In some rare cases, the alignment doesn't have to reach a satisfactory end.
            # In these cases, the final alignment is empty. Skip those cases.
            if len(alignment) == 0:
                empty_alignment_seg_names.append(seg_name)
                continue

            # Keep only the part of each allophone name before the first "{" (the central part of the allophone).
            center_allophones = np.array(
                [align_cache.allophones[allo_id].split("{")[0] for _, allo_id, _, _ in alignment]
            )
            phonemes, alignment_indices = self.extract_phoneme_sequence(center_allophones)
            viterbi_matrix = self.make_viterbi_matrix(alignment_indices)
            self.plot(viterbi_matrix, phonemes, file_name=seg_name, title=seg_name_to_text.get(seg_name, ""))

        if empty_alignment_seg_names:
            logging.warning(
                "The following alignments weren't plotted because their alignments were empty:\n"
                f"{empty_alignment_seg_names}"
            )
0 commit comments