diff --git a/lib/hdf.py b/lib/hdf.py
index 52705713..79ea2210 100644
--- a/lib/hdf.py
+++ b/lib/hdf.py
@@ -1,7 +1,8 @@
 import h5py
 import numpy as np
-from typing import Dict, Optional
-import sys
+import logging
+import shutil
+import sys
+import tempfile
+
+from typing import Dict, List, Optional
 
 
 def get_input_dict_from_returnn_hdf(hdf_file: h5py.File) -> Dict[str, np.ndarray]:
@@ -35,3 +36,95 @@ def get_returnn_simple_hdf_writer(returnn_root: Optional[str]):
     from returnn.datasets.hdf import SimpleHDFWriter
 
     return SimpleHDFWriter
+
+
+class NextGenHDFWriter:
+    """
+    Helper class for writing HDF files in the format expected by the RETURNN NextGenHDFDataset.
+    """
+
+    def __init__(
+        self,
+        filename: str,
+        label_info_dict: Dict,
+        feature_names: Optional[List[str]] = None,
+        label_data_type: type = np.uint16,
+        label_parser_name: str = "sparse",
+        feature_parser_name: str = "feature_sequence",
+    ):
+        """
+        :param filename: default output file path, used by :func:`finalize` unless overridden there
+        :param label_info_dict: a dictionary with the label targets used in RETURNN training as keys
+            and the number of label classes as values
+        :param feature_names: additional feature data names
+        :param label_data_type: type that is used to store the label data
+        :param label_parser_name: this should be checked against the RETURNN implementation
+        :param feature_parser_name: as above
+        """
+        self.filename = filename
+        self.label_info_dict = label_info_dict
+        self.label_parser_name = label_parser_name
+        self.feature_names = feature_names
+        self.feature_parser_name = feature_parser_name
+        self.label_data_type = label_data_type
+        self.string_data_type = h5py.special_dtype(vlen=str)
+        self.sequence_names = []
+        self.group_holder_dict = {}
+
+        self.file_init()
+
+    def file_init(self):
+        # delete=False: the temporary file is moved away in finalize(), so the
+        # automatic cleanup on close would otherwise fail
+        self.temp_file = tempfile.NamedTemporaryFile(suffix="_NextGenHDFWriter_outHDF", delete=False)
+        self.temp_path = self.temp_file.name
+        self.out_hdf = h5py.File(self.temp_path, "w")
+
+        logging.info(f"processing temporary file {self.temp_path}")
+
+        # root group under which one sub-group per data stream is created
+        self.root_group = self.out_hdf.create_group("streams")
+
+        for label_name, label_dim in self.label_info_dict.items():
+            self.group_holder_dict[label_name] = self._get_label_group(label_name, label_dim)
+
+        if self.feature_names is not None:
+            for feat_name in self.feature_names:
+                self.group_holder_dict[feat_name] = self._get_feature_group(feat_name)
+
+    def _get_label_group(self, label_name, label_dim):
+        assert label_dim > 0, "you should have at least dim 1"
+        label_group = self.root_group.create_group(label_name)
+        label_group.attrs["parser"] = self.label_parser_name
+        label_group.create_dataset(
+            "feature_names",
+            data=[b"label_%d" % i for i in range(label_dim)],
+            dtype=self.string_data_type,
+        )
+
+        return label_group.create_group("data")
+
+    def _get_feature_group(self, feature_name):
+        feature_group = self.root_group.create_group(feature_name)
+        feature_group.attrs["parser"] = self.feature_parser_name
+
+        return feature_group.create_group("data")
+
+    def add_sequence_name(self, seq_name):
+        self.sequence_names.append(seq_name)
+
+    def add_data_to_group(self, group_name, seq_name, data):
+        if group_name in self.label_info_dict:
+            data = np.array(data).astype(self.label_data_type)
+
+        # a "/" in the sequence name would implicitly create further HDF hierarchies, thus substitute it
+        self.group_holder_dict[group_name].create_dataset(seq_name.replace("/", "\\"), data=data)
+
+    def finalize(self, filename: Optional[str] = None):
+        """
+        :param filename: output file path, defaults to the filename given in the constructor
+        """
+        seq_name_set = set(s.replace("/", "\\") for s in self.sequence_names)
+
+        for group in self.group_holder_dict.values():
+            assert set(group.keys()) == seq_name_set, "The sequence names do not match between groups"
+
+        self.out_hdf.create_dataset(
+            "seq_names", data=[s.encode() for s in self.sequence_names], dtype=self.string_data_type
+        )
+
+        self.out_hdf.close()
+        shutil.move(self.temp_path, filename or self.filename)
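+
+
+# Minimal usage sketch (illustrative only: the stream names, dimensions and the
+# output path below are placeholders, not part of the API):
+#
+#     writer = NextGenHDFWriter(filename="out.hdf", label_info_dict={"classes": 42}, feature_names=["features"])
+#     writer.add_data_to_group("classes", "corpus/rec/seq-1", [1, 2, 3])
+#     writer.add_data_to_group("features", "corpus/rec/seq-1", np.zeros((3, 40)))
+#     writer.add_sequence_name("corpus/rec/seq-1")
+#     writer.finalize()
+#
+# The resulting HDF then holds one dataset per sequence under
+# "streams/<stream name>/data/" plus a top-level "seq_names" dataset.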
diff --git a/returnn/hdf.py b/returnn/hdf.py
index ee5a32ec..75d30a44 100644
--- a/returnn/hdf.py
+++ b/returnn/hdf.py
@@ -1,4 +1,10 @@
-__all__ = ["ReturnnDumpHDFJob", "ReturnnRasrDumpHDFJob", "BlissToPcmHDFJob", "RasrAlignmentDumpHDFJob"]
+__all__ = [
+    "ReturnnDumpHDFJob",
+    "ReturnnRasrDumpHDFJob",
+    "BlissToPcmHDFJob",
+    "RasrAlignmentDumpHDFJob",
+    "RasrDumpNextGenHDFJob",
+]
 
 from dataclasses import dataclass
 import glob
@@ -8,11 +14,11 @@
 import soundfile as sf
 import subprocess as sp
 import tempfile
-from typing import List, Optional
+from typing import Dict, List, Optional, Union
 
 from .rasr_training import ReturnnRasrTrainingJob
 from i6_core.lib import corpus
-from i6_core.lib.hdf import get_returnn_simple_hdf_writer
+from i6_core.lib.hdf import get_returnn_simple_hdf_writer, NextGenHDFWriter
 from i6_core.lib.rasr_cache import FileArchive
 import i6_core.rasr as rasr
 from i6_core.util import instanciate_delayed, uopen, write_paths_to_file
@@ -371,3 +377,135 @@ def run(self, task_id):
 
         if len(excluded_segments):
             write_paths_to_file(f"excluded_segments.{task_id}", excluded_segments)
+
+
+class RasrDumpNextGenHDFJob(Job):
+    """
+    This job reads RASR alignment and feature caches and dumps them into HDF files
+    for the NextGenHDFDataset class.
+    """
+
+    def __init__(
+        self,
+        alignment_caches_dict: Dict[str, List[tk.Path]],
+        allophones: Union[tk.Path, Dict[str, tk.Path]],
+        state_tyings: Union[tk.Path, Dict[str, tk.Path]],
+        reference_target: str,
+        data_type: type = np.uint16,
+        feature_caches_dict: Optional[Dict[str, List[tk.Path]]] = None,
+    ):
+        """
+        :param alignment_caches_dict: the dict keys are the target strings used in RETURNN training,
+            the values are e.g. the output of an AlignmentJob
+        :param allophones: e.g. the output of a StoreAllophonesJob, or a dict with the same keys
+            as alignment_caches_dict
+        :param state_tyings: e.g. the output of a DumpStateTyingJob, or a dict with the same keys
+            as alignment_caches_dict
+        :param reference_target: one of the keys of alignment_caches_dict that is taken as reference
+            for reading the segments
+        :param data_type: type that is used to store the data
+        :param feature_caches_dict: similar to alignment_caches_dict, just for features
+        """
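+        # Illustrative shape of the expected inputs (all paths and target names
+        # below are placeholders):
+        #   alignment_caches_dict = {"classes": [tk.Path("alignment.cache.1"), ...]}
+        #   allophones = tk.Path("allophones")               # shared across all targets, or
+        #   allophones = {"classes": tk.Path("allophones")}  # one entry per target
+        #   state_tyings = tk.Path("state-tying")            # analogous to allophones
+        #   reference_target = "classes"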
+        self.alignment_caches_dict = alignment_caches_dict
+        self.feature_caches_dict = feature_caches_dict
+        self.allophones = allophones
+        self.state_tyings = state_tyings
+        self.reference_target = reference_target
+        self.out_hdf_files = [
+            self.output_path(f"data.hdf.{d}") for d in range(len(self.alignment_caches_dict[reference_target]))
+        ]
+        self.out_excluded_segments = self.output_path("excluded.segments")
+        self.data_type = data_type
+        self.rqmt = {"cpu": 1, "mem": 8, "time": 0.5}
+
+    def tasks(self):
+        yield Task("run", rqmt=self.rqmt, args=range(1, len(self.out_hdf_files) + 1))
+        yield Task("merge", mini_task=True)
+
+    def _get_state_tying(self, state_tying_file):
+        with open(state_tying_file.get_path()) as f:
+            return {allophone: int(class_id) for allophone, class_id in (line.split()[:2] for line in f)}
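+
+    # The state-tying file parsed above maps allophone states to class ids,
+    # one pair per line, e.g. (illustrative):
+    #   AH{B+L}@i.0 42
+    #   AH{B+L}@i.1 43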
+
+    def _get_alignment_cache(self, task_id, alignment_name, allophones):
+        alignment_cache = FileArchive(self.alignment_caches_dict[alignment_name][task_id - 1].get_path())
+        alignment_cache.setAllophones(allophones.get_path())
+
+        return alignment_cache
+
+    def _get_label_sequence(self, alignment_cache, file, state_tying):
+        alignment = alignment_cache.read(file, "align")
+        if not len(alignment):
+            return None
+        alignment_states = ["%s.%d" % (alignment_cache.allophones[t[1]], t[2]) for t in alignment]
+        targets = [state_tying[allophone] for allophone in alignment_states]
+
+        return np.array(targets).astype(np.dtype(self.data_type))
+
+    def merge(self):
+        excluded_segments = []
+        excluded_files = glob.glob("excluded_segments.*")
+        for p in excluded_files:
+            if os.path.isfile(p):
+                with open(p, "r") as f:
+                    excluded_segments.extend(f.read().splitlines())
+
+        write_paths_to_file(self.out_excluded_segments, excluded_segments)
+
+    def run(self, task_id):
+        # this dict is first used to initialize the writer (name -> number of classes)
+        # and afterwards re-used to hold the alignment caches (name -> FileArchive)
+        alignment_dict = dict.fromkeys(self.alignment_caches_dict.keys(), None)
+        feature_names = list(self.feature_caches_dict.keys()) if self.feature_caches_dict is not None else None
+
+        assert (
+            self.reference_target in alignment_dict
+        ), "you did not define a proper target for the reference alignment cache"
+
+        allophones = {}
+        state_tyings = {}
+        for k in alignment_dict.keys():
+            allophones[k] = self.allophones if not isinstance(self.allophones, dict) else self.allophones[k]
+            state_tying_path = self.state_tyings if not isinstance(self.state_tyings, dict) else self.state_tyings[k]
+            state_tyings[k] = self._get_state_tying(state_tying_path)
+            alignment_dict[k] = max(state_tyings[k].values()) + 1  # max label class id + 1
+
+        hdf_writer = NextGenHDFWriter(
+            filename=f"hdf.{task_id - 1}",
+            label_info_dict=alignment_dict,
+            feature_names=feature_names,
+            label_data_type=self.data_type,
+        )
+
+        for k in alignment_dict.keys():
+            alignment_dict[k] = self._get_alignment_cache(task_id, k, allophones[k])
+
+        feature_dict = (
+            {n: FileArchive(self.feature_caches_dict[n][task_id - 1].get_path()) for n in feature_names}
+            if feature_names is not None
+            else {}
+        )
+
+        excluded_segments = []
+
+        for file in alignment_dict[self.reference_target].ft:
+            info = alignment_dict[self.reference_target].ft[file]
+            if info.name.endswith(".attribs"):
+                continue
+            seq_name = info.name
+
+            # collect the label sequences of all targets first, so that a segment with a
+            # missing alignment in any target is excluded from all groups consistently
+            label_seqs = {}
+            for align_k in alignment_dict.keys():
+                label_seq = self._get_label_sequence(alignment_dict[align_k], file, state_tyings[align_k])
+                if label_seq is None:
+                    break
+                label_seqs[align_k] = label_seq
+
+            if len(label_seqs) < len(alignment_dict):
+                if seq_name not in excluded_segments:
+                    excluded_segments.append(seq_name)
+                continue
+
+            for align_k, label_seq in label_seqs.items():
+                hdf_writer.add_data_to_group(align_k, seq_name, label_seq)
+
+            for feat_key in feature_dict.keys():
+                times, features = feature_dict[feat_key].read(file, "feat")
+                hdf_writer.add_data_to_group(feat_key, seq_name, features)
+
+            hdf_writer.add_sequence_name(seq_name)
+
+        hdf_writer.finalize(self.out_hdf_files[task_id - 1])
+
+        if len(excluded_segments):
+            write_paths_to_file(f"excluded_segments.{task_id}", excluded_segments)
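+
+
+# Minimal setup sketch (illustrative: the cache paths and target names are
+# placeholders for the outputs of the corresponding RASR jobs):
+#
+#     job = RasrDumpNextGenHDFJob(
+#         alignment_caches_dict={"classes": [tk.Path("alignment.cache.1")]},
+#         allophones=tk.Path("allophones"),
+#         state_tyings=tk.Path("state-tying"),
+#         reference_target="classes",
+#         feature_caches_dict={"features": [tk.Path("features.cache.1")]},
+#     )
+#     tk.register_output("data.hdf.0", job.out_hdf_files[0])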