diff --git a/common/baselines/tedlium2/default_tools.py b/common/baselines/tedlium2/default_tools.py index 42390b0c8..a595e99ec 100644 --- a/common/baselines/tedlium2/default_tools.py +++ b/common/baselines/tedlium2/default_tools.py @@ -8,14 +8,15 @@ version listed here. Nevertheless, the most recent "head" should be safe to be used as well. """ from sisyphus import tk -from i6_experiments.common.tools.audio import compile_ffmpeg_binary from i6_experiments.common.tools.rasr import compile_rasr_binaries_i6mode from i6_experiments.common.tools.sctk import compile_sctk +from i6_core.tools.git import CloneGitRepositoryJob + +PACKAGE = __package__ RASR_BINARY_PATH = compile_rasr_binaries_i6mode( - branch="apptainer_tf_2_8", configure_options=["--apptainer-patch=2023-05-08_tensorflow-2.8_v1"] + configure_options=["--apptainer-patch=2023-05-08_tensorflow-2.8_v1"] ) # use most recent RASR -# RASR_BINARY_PATH = tk.Path("/work/asr4/rossenbach/neon_test/rasr_versions/rasr_no_tf/arch/linux-x86_64-standard/") assert RASR_BINARY_PATH, "Please set a specific RASR_BINARY_PATH before running the pipeline" RASR_BINARY_PATH.hash_overwrite = "TEDLIUM2_DEFAULT_RASR_BINARY_PATH" @@ -25,3 +26,14 @@ SRILM_PATH = tk.Path("/work/tools/users/luescher/srilm-1.7.3/bin/i686-m64/") SRILM_PATH.hash_overwrite = "TEDLIUM2_DEFAULT_SRILM_PATH" + +RETURNN_EXE = tk.Path( + "/usr/bin/python3", + hash_overwrite="GENERIC_RETURNN_LAUNCHER", +) + +RETURNN_RC_ROOT = CloneGitRepositoryJob( + "https://github.com/rwth-i6/returnn", + commit="11d33468ad56a6c254168560c29e77e65eb45b7c", +).out_repository +RETURNN_RC_ROOT.hash_overwrite = "TEDLIUM2_DEFAULT_RETURNN_RC_ROOT" diff --git a/common/baselines/tedlium2/gmm/baseline_args.py b/common/baselines/tedlium2/gmm/baseline_args.py index b227cccd1..3910410c3 100644 --- a/common/baselines/tedlium2/gmm/baseline_args.py +++ b/common/baselines/tedlium2/gmm/baseline_args.py @@ -3,7 +3,9 @@ from i6_experiments.common.setups.rasr import util from i6_experiments.common.datasets.tedlium2.cart import CartQuestions -from i6_experiments.common.baselines.librispeech.default_tools import SCTK_BINARY_PATH +from i6_experiments.common.baselines.tedlium2.default_tools import SCTK_BINARY_PATH + +USE_CORRECTED_APPLICATOR = True def get_init_args(): @@ -86,7 +88,7 @@ def get_monophone_args(): "extra_merge_args": None, "extra_config": None, "extra_post_config": None, - "use_corrected_applicator": False, + "use_corrected_applicator": USE_CORRECTED_APPLICATOR, } monophone_training_args = { @@ -97,7 +99,7 @@ def get_monophone_args(): "splits": 10, "accs_per_split": 2, "dump_alignment_score_report": True, - "use_corrected_applicator": False, + "use_corrected_applicator": USE_CORRECTED_APPLICATOR, } monophone_recognition_args = { @@ -184,7 +186,7 @@ def get_triphone_args(): "align_extra_rqmt": {"mem": 8}, "accumulate_extra_rqmt": {"mem": 8}, "split_extra_rqmt": {"mem": 8}, - "use_corrected_applicator": False, + "use_corrected_applicator": USE_CORRECTED_APPLICATOR, } triphone_recognition_args = { @@ -250,7 +252,7 @@ def get_vtln_args(): "align_extra_rqmt": {"mem": 8}, "accumulate_extra_rqmt": {"mem": 8}, "split_extra_rqmt": {"mem": 8}, - "use_corrected_applicator": False, + "use_corrected_applicator": USE_CORRECTED_APPLICATOR, }, } @@ -306,7 +308,7 @@ def get_sat_args(): "align_extra_rqmt": {"mem": 8}, "accumulate_extra_rqmt": {"mem": 8}, "split_extra_rqmt": {"mem": 8}, - "use_corrected_applicator": False, + "use_corrected_applicator": USE_CORRECTED_APPLICATOR, } sat_recognition_args = { @@ -320,10 +322,10 @@ def 
get_sat_args(): "feature_cache": "mfcc", "cache_regex": "^mfcc.*$", "cmllr_mixtures": "estimate_mixtures_sdm.tri", - "iters": [8, 10], + "iters": [8, 9, 10], "feature_flow": "uncached_mfcc+context+lda", - "pronunciation_scales": [1.0], - "lm_scales": [25], + "pronunciation_scales": [0.0], + "lm_scales": [8.0, 20.0, 25.0], "lm_lookahead": True, "lookahead_options": None, "create_lattice": True, @@ -371,7 +373,7 @@ def get_vtln_sat_args(): "align_extra_rqmt": {"mem": 8}, "accumulate_extra_rqmt": {"mem": 8}, "split_extra_rqmt": {"mem": 8}, - "use_corrected_applicator": False, + "use_corrected_applicator": USE_CORRECTED_APPLICATOR, } vtln_sat_recognition_args = { @@ -385,10 +387,10 @@ def get_vtln_sat_args(): "feature_cache": "mfcc", "cache_regex": "^mfcc.*$", "cmllr_mixtures": "estimate_mixtures_sdm.vtln", - "iters": [8, 10], + "iters": [8, 9, 10], "feature_flow": "uncached_mfcc+context+lda+vtln", - "pronunciation_scales": [1.0], - "lm_scales": [25], + "pronunciation_scales": [0.0], + "lm_scales": [8.0, 20.0, 25.0], "lm_lookahead": True, "lookahead_options": None, "create_lattice": True, diff --git a/common/baselines/tedlium2/hybrid/__init__.py b/common/baselines/tedlium2/hybrid/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/common/baselines/tedlium2/hybrid/baseline_args.py b/common/baselines/tedlium2/hybrid/baseline_args.py new file mode 100644 index 000000000..75b2ac679 --- /dev/null +++ b/common/baselines/tedlium2/hybrid/baseline_args.py @@ -0,0 +1,51 @@ +from i6_core.features import filter_width_from_channels + + +def get_gammatone_feature_extraction_args(): + return { + "gt_options": { + "minfreq": 100, + "maxfreq": 7500, + "channels": 50, + "tempint_type": "hanning", + "tempint_shift": 0.01, + "tempint_length": 0.025, + "flush_before_gap": True, + "do_specint": False, + "specint_type": "hanning", + "specint_shift": 4, + "specint_length": 9, + "normalize": True, + "preemphasis": True, + "legacy_scaling": False, + "without_samples": False, + "samples_options": { + "audio_format": "wav", + "dc_detection": False, + }, + "normalization_options": {}, + } + } + + +def get_log_mel_feature_extraction_args(): + + return { + "fb": { + "filterbank_options": { + "warping_function": "mel", + "filter_width": filter_width_from_channels(channels=80, warping_function="mel", f_max=8000), + "normalize": True, + "normalization_options": None, + "without_samples": False, + "samples_options": { + "audio_format": "wav", + "dc_detection": False, + }, + "fft_options": None, + "add_features_output": True, + "apply_log": True, + "add_epsilon": True, + } + } + } diff --git a/common/baselines/tedlium2/hybrid/baseline_config.py b/common/baselines/tedlium2/hybrid/baseline_config.py new file mode 100644 index 000000000..77df7617f --- /dev/null +++ b/common/baselines/tedlium2/hybrid/baseline_config.py @@ -0,0 +1,71 @@ +import copy +from sisyphus import gs, tk + +from i6_core.features import FilterbankJob + +from i6_experiments.common.setups.rasr.util import RasrSteps + +from i6_experiments.common.setups.rasr.hybrid_system import HybridSystem +from i6_experiments.common.baselines.tedlium2.default_tools import RETURNN_RC_ROOT, RASR_BINARY_PATH + +from .data import get_corpus_data_inputs +from .baseline_args import get_log_mel_feature_extraction_args +from .nn_config.nn_args import get_nn_args + + +def run_gmm_system(): + from i6_experiments.common.baselines.tedlium2.gmm.baseline_config import ( + run_tedlium2_common_baseline, + ) + + system = run_tedlium2_common_baseline() + return system + + +def
run_tedlium2_hybrid_baseline(): + gs.ALIAS_AND_OUTPUT_SUBDIR = "baselines/tedlium2/hybrid/baseline" + + gmm_system = run_gmm_system() + rasr_init_args = copy.deepcopy(gmm_system.rasr_init_args) + rasr_init_args.feature_extraction_args = get_log_mel_feature_extraction_args() + ( + nn_train_data_inputs, + nn_cv_data_inputs, + nn_devtrain_data_inputs, + nn_dev_data_inputs, + nn_test_data_inputs, + ) = get_corpus_data_inputs( + gmm_system, + rasr_init_args.feature_extraction_args["fb"], + FilterbankJob, + alias_prefix="experiments/tedlium2/hybrid/wei_baseline", + ) + # image only, so just python3 + returnn_exe = tk.Path("/usr/bin/python3", hash_overwrite="GENERIC_RETURNN_LAUNCHER") + blas_lib = tk.Path("/lib/x86_64-linux-gnu/liblapack.so.3") + blas_lib.hash_overwrite = "TEDLIUM2_DEFAULT_BLAS_LIB" + steps = RasrSteps() + steps.add_step("extract", rasr_init_args.feature_extraction_args) + gmm_system.run(steps) + nn_args = get_nn_args(num_epochs=160) + nn_steps = RasrSteps() + nn_steps.add_step("nn", nn_args) + + tedlium_nn_system = HybridSystem( + returnn_root=RETURNN_RC_ROOT, + returnn_python_exe=returnn_exe, + blas_lib=blas_lib, + rasr_arch="linux-x86_64-standard", + rasr_binary_path=RASR_BINARY_PATH, + ) + tedlium_nn_system.init_system( + rasr_init_args=rasr_init_args, + train_data=nn_train_data_inputs, + cv_data=nn_cv_data_inputs, + devtrain_data=nn_devtrain_data_inputs, + dev_data=nn_dev_data_inputs, + test_data=nn_test_data_inputs, + train_cv_pairing=[tuple(["train.train", "dev.cv"])], + ) + tedlium_nn_system.run(nn_steps) + + gs.ALIAS_AND_OUTPUT_SUBDIR = "" diff --git a/common/baselines/tedlium2/hybrid/data.py b/common/baselines/tedlium2/hybrid/data.py new file mode 100644 index 000000000..e012a7e1b --- /dev/null +++ b/common/baselines/tedlium2/hybrid/data.py @@ -0,0 +1,259 @@ +from typing import Optional, Dict, Any, Tuple, Callable +from sisyphus import tk + +from i6_core import corpus as corpus_recipe +from i6_core.returnn import ReturnnDumpHDFJob +from i6_core.features import FeatureExtractionJob + +from i6_experiments.common.datasets.tedlium2.constants import DURATIONS, NUM_SEGMENTS +from i6_experiments.common.setups.rasr.gmm_system import GmmSystem +from i6_experiments.common.setups.rasr.util import ( + HdfDataInput, + AllophoneLabeling, + ReturnnRasrDataInput, + ForcedAlignmentArgs, +) +from i6_experiments.common.datasets.tedlium2.lexicon import get_g2p_augmented_bliss_lexicon +from ..default_tools import RETURNN_EXE, RETURNN_RC_ROOT + + +def build_hdf_data_input( + features: tk.Path, + allophone_labeling: AllophoneLabeling, + alignments: tk.Path, + segment_list: Optional[tk.Path] = None, + alias_prefix: Optional[str] = None, + partition_epoch: int = 1, + acoustic_mixtures: Optional[tk.Path] = None, + seq_ordering: str = "sorted", +) -> HdfDataInput: + """ + Dumps features and alignments from RASR into hdfs, to enable full RETURNN training + :param features: Feature bundle generated by the dump_features_for_hybrid_training function + :param allophone_labeling: Allophone labeling including silence_phoneme, allophones and state_tying + :param alignments: Target alignments generated from the pre-trained GMM + :param segment_list: segment list for the alignment dataset which will serve as seq_control dataset + :param alias_prefix: Prefix for the dump jobs + :param partition_epoch: Partition epoch for the alignment dataset, mainly relevant for training dataset + :param acoustic_mixtures: Acoustic mixture file from the GMM for prior calculation, most likely going to be replaced + :param
seq_ordering: sequence ordering for the align dataset, usually sorted for dev/eval and laplace for train + :return: HdfDataInput with corresponding hdf datasets + """ + + feat_dataset = { + "class": "SprintCacheDataset", + "data": { + "data": { + "filename": features, + "data_type": "feat", + "allophone_labeling": { + "silence_phone": allophone_labeling.silence_phone, + "allophone_file": allophone_labeling.allophone_file, + "state_tying_file": allophone_labeling.state_tying_file, + }, + } + }, + "seq_list_filter_file": segment_list, + } + + feat_job = ReturnnDumpHDFJob(data=feat_dataset, returnn_python_exe=RETURNN_EXE, returnn_root=RETURNN_RC_ROOT) + if alias_prefix is not None: + feat_job.add_alias(alias_prefix + "/dump_features") + feat_hdf = feat_job.out_hdf + + align_dataset = { + "class": "SprintCacheDataset", + "data": { + "data": { + "filename": alignments, + "data_type": "align", + "allophone_labeling": { + "silence_phone": allophone_labeling.silence_phone, + "allophone_file": allophone_labeling.allophone_file, + "state_tying_file": allophone_labeling.state_tying_file, + }, + } + }, + "seq_list_filter_file": segment_list, + } + align_job = ReturnnDumpHDFJob(data=align_dataset, returnn_python_exe=RETURNN_EXE, returnn_root=RETURNN_RC_ROOT) + if alias_prefix is not None: + align_job.add_alias(alias_prefix + "/dump_alignments") + align_hdf = align_job.out_hdf + + return HdfDataInput( + features=feat_hdf, + alignments=align_hdf, + partition_epoch=partition_epoch, + acoustic_mixtures=acoustic_mixtures, + seq_ordering=seq_ordering, + segment_file=segment_list, + ) + + +def dump_features_for_hybrid_training( + gmm_system: GmmSystem, + feature_extraction_args: Dict[str, Any], + feature_extraction_class: Callable[..., FeatureExtractionJob], +) -> Tuple[tk.Path, tk.Path, tk.Path]: + """ + Extracts the hybrid training features for the nn-train, nn-cv and nn-devtrain corpora + + :param gmm_system: GMM system to get corpora from + :param feature_extraction_args: Args for the feature extraction + :param feature_extraction_class: Feature extraction class/job to be used for extraction + :return: paths to the train, cv and devtrain features + """ + features = {} + for name in ["nn-train", "nn-cv", "nn-devtrain"]: + features[name] = list( + feature_extraction_class(gmm_system.crp[name], **feature_extraction_args).out_feature_bundle.values() + )[0] + return features["nn-train"], features["nn-cv"], features["nn-devtrain"] + + +def get_corpus_data_inputs( + gmm_system: GmmSystem, + feature_extraction_args: Dict[str, Any], + feature_extraction_class: Callable[..., FeatureExtractionJob], + alias_prefix: Optional[str] = None, +) -> Tuple[ + Dict[str, HdfDataInput], + Dict[str, HdfDataInput], + Dict[str, HdfDataInput], + Dict[str, ReturnnRasrDataInput], + Dict[str, ReturnnRasrDataInput], +]: + """ + Builds the data inputs for the hybrid system, including 3 training hdf pairs with align and feature dataset for + full returnn training + :param gmm_system: Pre-trained GMM-system to derive the hybrid setup from + :param feature_extraction_args: Args for the feature extraction of the hybrid features (might be different from GMM) + :param feature_extraction_class: Feature extraction class/job to be used for extraction + :param alias_prefix: Prefix for naming of experiments + :return: HdfDataInputs for the train sets and ReturnnRasrDataInputs for the dev and test sets + """ + + train_corpus_path = gmm_system.corpora["train"].corpus_file + cv_corpus_path = gmm_system.corpora["dev"].corpus_file + + cv_corpus_path = corpus_recipe.FilterCorpusRemoveUnknownWordSegmentsJob( +
bliss_corpus=cv_corpus_path, bliss_lexicon=get_g2p_augmented_bliss_lexicon(), all_unknown=False + ).out_corpus + + total_train_num_segments = NUM_SEGMENTS["train"] + + all_train_segments = corpus_recipe.SegmentCorpusJob(train_corpus_path, 1).out_single_segment_files[1] + cv_segments = corpus_recipe.SegmentCorpusJob(cv_corpus_path, 1).out_single_segment_files[1] + + dev_train_size = 500 / total_train_num_segments + splitted_train_segments_job = corpus_recipe.ShuffleAndSplitSegmentsJob( + all_train_segments, + {"devtrain": dev_train_size, "unused": 1 - dev_train_size}, + ) + devtrain_segments = splitted_train_segments_job.out_segments["devtrain"] + + # ******************** NN Init ******************** + + gmm_system.add_overlay("train", "nn-train") + gmm_system.crp["nn-train"].segment_path = all_train_segments + gmm_system.crp["nn-train"].concurrent = 1 + gmm_system.crp["nn-train"].corpus_duration = DURATIONS["train"] + + gmm_system.add_overlay("dev", "nn-cv") + gmm_system.crp["nn-cv"].segment_path = cv_segments + gmm_system.crp["nn-cv"].concurrent = 1 + gmm_system.crp["nn-cv"].corpus_duration = DURATIONS["dev"] + + gmm_system.add_overlay("train", "nn-devtrain") + gmm_system.crp["nn-devtrain"].segment_path = devtrain_segments + gmm_system.crp["nn-devtrain"].concurrent = 1 + gmm_system.crp["nn-devtrain"].corpus_duration = DURATIONS["train"] * dev_train_size + + # ******************** extract features ******************** + + train_features, cv_features, devtrain_features = dump_features_for_hybrid_training( + gmm_system, + feature_extraction_args, + feature_extraction_class, + ) + + allophone_labeling = AllophoneLabeling( + silence_phone="[SILENCE]", + allophone_file=gmm_system.allophone_files["train"], + state_tying_file=gmm_system.jobs["train"]["state_tying_gmm_out"].out_state_tying, + ) + + forced_align_args = ForcedAlignmentArgs( + name="nn-cv", + target_corpus_keys=["nn-cv"], + flow="uncached_mfcc+context+lda+vtln+cmllr", + feature_scorer="train_vtln+sat", + scorer_index=-1, + bliss_lexicon={ + "filename": get_g2p_augmented_bliss_lexicon(), + "normalize_pronunciation": False, + }, + dump_alignment=True, + ) + gmm_system.run_forced_align_step(forced_align_args) + + nn_train_data = build_hdf_data_input( + features=train_features, + alignments=gmm_system.outputs["train"]["final"].as_returnn_rasr_data_input().alignments.alternatives["bundle"], + allophone_labeling=allophone_labeling, + alias_prefix=alias_prefix + "/nn_train_data", + partition_epoch=5, + acoustic_mixtures=None, + seq_ordering="laplace:.1000", + ) + tk.register_output(f"{alias_prefix}/nn_train_data/features", nn_train_data.features) + tk.register_output(f"{alias_prefix}/nn_train_data/alignments", nn_train_data.alignments) + nn_devtrain_data = build_hdf_data_input( + features=devtrain_features, + alignments=gmm_system.outputs["train"]["final"].as_returnn_rasr_data_input().alignments.alternatives["bundle"], + allophone_labeling=allophone_labeling, + segment_list=devtrain_segments, + alias_prefix=alias_prefix + "/nn_devtrain_data", + partition_epoch=1, + seq_ordering="sorted", + ) + tk.register_output(f"{alias_prefix}/nn_devtrain_data/features", nn_devtrain_data.features) + tk.register_output(f"{alias_prefix}/nn_devtrain_data/alignments", nn_devtrain_data.alignments) + nn_cv_data = build_hdf_data_input( + features=cv_features, + alignments=gmm_system.alignments["nn-cv_forced-align"]["nn-cv"].alternatives["bundle"], + allophone_labeling=allophone_labeling, + alias_prefix=alias_prefix + "/nn_cv_data", + partition_epoch=1, + 
seq_ordering="sorted", + ) + tk.register_output(f"{alias_prefix}/nn_cv_data/features", nn_cv_data.features) + tk.register_output(f"{alias_prefix}/nn_cv_data/alignments", nn_cv_data.alignments) + + nn_train_data_inputs = { + "train.train": nn_train_data, + } + nn_devtrain_data_inputs = { + "train.devtrain": nn_devtrain_data, + } + + nn_cv_data_inputs = { + "dev.cv": nn_cv_data, + } + + nn_dev_data_inputs = { + "dev": gmm_system.outputs["dev"]["final"].as_returnn_rasr_data_input(), + } + nn_test_data_inputs = { + # "test": gmm_system.outputs["test"][ + # "final" + # ].as_returnn_rasr_data_input(), + } + + return ( + nn_train_data_inputs, + nn_cv_data_inputs, + nn_devtrain_data_inputs, + nn_dev_data_inputs, + nn_test_data_inputs, + ) diff --git a/common/baselines/tedlium2/hybrid/nn_config/experiment.py b/common/baselines/tedlium2/hybrid/nn_config/experiment.py new file mode 100644 index 000000000..b79581ec5 --- /dev/null +++ b/common/baselines/tedlium2/hybrid/nn_config/experiment.py @@ -0,0 +1,17 @@ +from .helper import get_network +from .helper import make_nn_config + + +def get_baseline_config(specaug=False): + network = get_network(spec_augment=specaug) + nn_config = make_nn_config(network) + nn_config["extern_data"] = { + "data": { + "dim": 80, + "shape": (None, 80), + "available_for_inference": True, + }, # input: 80-dimensional logmel features + "classes": {"dim": 9001, "shape": (None,), "available_for_inference": True, "sparse": True, "dtype": "int16"}, + } + + return nn_config diff --git a/common/baselines/tedlium2/hybrid/nn_config/helper.py b/common/baselines/tedlium2/hybrid/nn_config/helper.py new file mode 100644 index 000000000..3f8dec4e5 --- /dev/null +++ b/common/baselines/tedlium2/hybrid/nn_config/helper.py @@ -0,0 +1,86 @@ +import copy + + +from .nn_setup import build_encoder_network, add_output_layer, add_specaug_source_layer + + +default_nn_config_args = { # batching # + "batch_size": 10000, + "max_seqs": 128, + "chunking": "64:32", # better than 50:25 + "batching": "random", + "min_seq_length": {"classes": 1}, + # optimization # + #'nadam': True, + "learning_rate": 0.0009, + "gradient_clip": 0, + "gradient_noise": 0.1, # together with l2 and dropout for overfit + # Note: (default 1e-8) likely not too much impact + #'optimizer_epsilon': 1e-8, + "optimizer": {"class": "nadam", "epsilon": 1e-08}, + # let it stop and adjust in time + # Note: for inf or nan, sth. is too big (e.g. 
lr warm up) + # 'stop_on_nonfinite_train_score' : False, + "learning_rate_control": "newbob_multi_epoch", + "newbob_multi_num_epochs": 5, + "newbob_multi_update_interval": 1, + "newbob_learning_rate_decay": 0.9, + # 'newbob_relative_error_threshold' : -0.02, # default -0.01 + # 'min_learning_rate' : 1e-5 # + "learning_rate_control_relative_error_relative_lr": True, + "learning_rate_control_min_num_epochs_per_new_lr": 3, + "start_epoch": "auto", + "start_batch": "auto", + "use_tensorflow": True, + "update_on_device": True, + "multiprocessing": True, + "cache_size": "0", + "truncation": -1, + "window": 1, +} + + +def get_network(num_layers=6, layer_size=512, spec_augment=False, **kwargs): + + lstm_args = { + "num_layers": num_layers, + "size": layer_size, + "l2": 0.01, + "dropout": 0.1, + "bidirectional": True, + "unit": "nativelstm2", + } + network, from_list = build_encoder_network(**lstm_args) + + output_args = { + "loss": "ce", + "loss_opts": { # less weight on loss of easy samples (larger p) + "focal_loss_factor": 2.0, + }, + } + + network = add_output_layer(network, from_list, **output_args) + + if spec_augment: + network, _ = add_specaug_source_layer(network) + + return copy.deepcopy(network) + + +def make_nn_config(network, nn_config_args=default_nn_config_args, **kwargs): + + nn_config = copy.deepcopy(nn_config_args) + nn_config["network"] = network + + # common training settings + optimizer = kwargs.pop("optimizer", None) + if optimizer is not None and optimizer != "nadam": + # pop with a default, the base settings may not define these keys + nn_config.pop("nadam", None) + nn_config[optimizer] = True + if kwargs.pop("no_pretrain", False): + nn_config.pop("pretrain", None) + if kwargs.pop("no_chunking", False): + nn_config.pop("chunking", None) + # Note: whatever is left goes directly into the config + nn_config.update(kwargs) + return nn_config diff --git a/common/baselines/tedlium2/hybrid/nn_config/nn_args.py b/common/baselines/tedlium2/hybrid/nn_config/nn_args.py new file mode 100644 index 000000000..989789102 --- /dev/null +++ b/common/baselines/tedlium2/hybrid/nn_config/nn_args.py @@ -0,0 +1,90 @@ +import copy + +from .experiment import get_baseline_config +from .nn_setup import get_spec_augment_mask_python +from i6_core.returnn.config import ReturnnConfig +from i6_experiments.common.setups.rasr.util import HybridArgs, ReturnnTrainingJobArgs +from copy import deepcopy + + +def get_nn_args(num_epochs=125): + + base_config = get_baseline_config() + returnn_config = ReturnnConfig(config=base_config) + + spec_augment_args = { + "max_time_num": 3, + "max_time": 10, + "max_feature_num": 5, + "max_feature": 18, + "conservatvie_step": 2000, + } + specaug = get_spec_augment_mask_python(**spec_augment_args) + specaug_config = get_baseline_config(specaug=True) + spec_cfg = ReturnnConfig(config=copy.deepcopy(specaug_config), python_epilog=specaug) + + configs = { + "base_config": returnn_config, + "specaug_config": spec_cfg, + } + + # change softmax to log softmax (linear + log_softmax activation) for hybrid recognition + recog_configs = deepcopy(configs) + for config_name in recog_configs: + recog_configs[config_name].config["network"]["output"]["class"] = "linear" + recog_configs[config_name].config["network"]["output"]["activation"] = "log_softmax" + + # arguments for ReturnnTraining for now fixed + training_args = ReturnnTrainingJobArgs( + num_epochs=num_epochs, + log_verbosity=5, + save_interval=1, + keep_epochs=None, + time_rqmt=168, + mem_rqmt=8, + cpu_rqmt=3, + ) + + recognition_args = { + "dev": { + "epochs": [num_epochs], + "feature_flow_key":
"fb", + "prior_scales": [0.7, 0.8, 0.9], + "pronunciation_scales": [0.0], + "lm_scales": [10.0, 7.5, 5.0], + "lm_lookahead": True, + "lookahead_options": None, + "create_lattice": True, + "eval_single_best": True, + "eval_best_in_lattice": True, + "search_parameters": { + "beam-pruning": 14.0, + "beam-pruning-limit": 100000, + "word-end-pruning": 0.5, + "word-end-pruning-limit": 15000, + }, + "lattice_to_ctm_kwargs": { + "fill_empty_segments": True, + "best_path_algo": "bellman-ford", + }, + "optimize_am_lm_scale": True, + "rtf": 50, + "mem": 8, + "lmgc_mem": 16, + "cpu": 4, + "parallelize_conversion": True, + "use_epoch_for_compile": True, + "native_ops": ["NativeLstm2"], + }, + } + + nn_args = HybridArgs( + returnn_training_configs=configs, + returnn_recognition_configs=recog_configs, + training_args=training_args, + recognition_args=recognition_args, + test_recognition_args=None, + ) + + return nn_args diff --git a/common/baselines/tedlium2/hybrid/nn_config/nn_setup.py b/common/baselines/tedlium2/hybrid/nn_config/nn_setup.py new file mode 100644 index 000000000..3340c5d6a --- /dev/null +++ b/common/baselines/tedlium2/hybrid/nn_config/nn_setup.py @@ -0,0 +1,854 @@ +## lm_config +import os +import copy + +# ------------------------------ Recipes ------------------------------ +from sisyphus import tk + +Path = tk.Path + +# only used for seq-training/full_sum training +# import recipe.crnn as crnn + +### construct RETURNN network layers on demand of training ### + + +def make_network(): + network = dict() + fromList = ["data"] + return network, fromList + + +def add_loss_to_layer(network, name, loss, loss_opts=None, target=None, **kwargs): + assert loss is not None + network[name]["loss"] = loss + if loss_opts: + network[name]["loss_opts"] = loss_opts + if target is not None: + network[name]["target"] = target + return network + + +def add_specaug_source_layer(network, name="source", nextLayers=["fwd_lstm_1", "bwd_lstm_1"]): + network2 = copy.deepcopy(network) + network2[name] = { + "class": "eval", + "eval": "self.network.get_config().typed_value('transform')(source(0, as_data=True), network=self.network)", + } + for layer in nextLayers: + if not layer in network2: + continue + network2[layer]["from"] = [name] + return network2, name + + +def add_linear_layer( + network: object, + name: object, + fromList: object, + size: object, + l2: object = 0.01, + dropout: object = None, + bias: object = None, + activation: object = None, + **kwargs, +) -> object: + network[name] = {"class": "linear", "n_out": size, "from": fromList, "activation": activation} + if l2 is not None: + network[name]["L2"] = l2 + if dropout is not None: + network[name]["dropout"] = dropout + # bias is default true in RETURNN + if bias is not None: + network[name]["with_bias"] = bias + if kwargs.get("random_norm_init", False): + network[name]["forward_weights_init"] = "random_normal_initializer(mean=0.0, stddev=0.1)" + if kwargs.get("initial", None) is not None: + network[name]["initial_output"] = kwargs.get("initial", None) + if kwargs.get("loss", None) is not None: + network = add_loss_to_layer(network, name, **kwargs) + if kwargs.get("reuse_params", None) is not None: + network[name]["reuse_params"] = kwargs.get("reuse_params", None) + if not kwargs.get("trainable", True): + network[name]["trainable"] = False + if kwargs.get("out_type", None) is not None: + network[name]["out_type"] = kwargs.get("out_type", None) + # Note: this is not in the master RETURNN branch + if kwargs.get("safe_embedding", False): + 
network[name]["safe_embedding"] = True # 0-vectors for out-of-range ids (only for embedding) + if kwargs.get("validate_indices", False): + network[name]["validate_indices"] = True # round out-of-range ids to 0 (only for embedding) + return network, name + + +def add_activation_layer(network, name, fromList, activation, **kwargs): + network[name] = {"class": "activation", "from": fromList, "activation": activation} + if kwargs.get("loss", None) is not None: + network = add_loss_to_layer(network, name, **kwargs) + return network, name + + +def add_lstm_layer( + network, name, fromList, size, l2=0.01, dropout=0.1, bidirectional=True, unit="nativelstm2", **kwargs +): + if bidirectional: + layers = [("fwd_" + name, 1), ("bwd_" + name, -1)] + else: + layers = [(name, 1)] + + names = [] + for n, d in layers: + network[n] = { + "class": "rec", + "unit": unit, + "n_out": size, + "from": fromList, + "direction": d, + "dropout": dropout, + "L2": l2, + } + if kwargs.get("drop_connect", None) is not None: + network[n]["unit_opts"] = {"rec_weight_dropout": kwargs.get("drop_connect", None)} + if kwargs.get("random_norm_init", False): + network[n]["forward_weights_init"] = "random_normal_initializer(mean=0.0, stddev=0.1)" + network[n]["recurrent_weights_init"] = "random_normal_initializer(mean=0.0, stddev=0.1)" + network[n]["bias_init"] = "random_normal_initializer(mean=0.0, stddev=0.1)" + if not kwargs.get("trainable", True): + network[n]["trainable"] = False + names.append(n) + + if len(names) == 1: + names = names[0] + return network, names + + +def add_constant_layer(network, name, value, dtype="int32", with_batch_dim=True, **kwargs): + network[name] = {"class": "constant", "value": value, "dtype": dtype, "with_batch_dim": with_batch_dim} + if kwargs.get("out_type", {}): + network[name]["out_type"] = kwargs.get("out_type", {}) + if kwargs.get("initial", None) is not None: + network[name]["initial_output"] = kwargs.get("initial", None) + return network, name + + +def add_cast_layer(network, name, fromList, dtype="float32"): + network[name] = {"class": "cast", "from": fromList, "dtype": dtype} + return network, name + + +def add_expand_dim_layer(network, name, fromList, axis, out_type=None): + network[name] = {"class": "expand_dims", "from": fromList, "axis": axis} # if int, then automatically batch major + if out_type is not None: + network[name]["out_type"] = out_type + return network, name + + +def add_copy_layer(network, name, fromList, initial=None, loss=None, **kwargs): + network[name] = {"class": "copy", "from": fromList} + if initial is not None: + network[name]["initial_output"] = initial + if loss is not None: + network = add_loss_to_layer(network, name, loss, **kwargs) + if kwargs.get("is_output", False): + network[name]["is_output_layer"] = True + if kwargs.get("dropout", None) is not None: + network[name]["dropout"] = kwargs.get("dropout", None) + return network, name + + +def add_compare_layer(network, name, fromList, value=None, kind="not_equal", initial=None): + network[name] = {"class": "compare", "from": fromList, "kind": kind} + if value is not None: + network[name]["value"] = value + if initial is not None: + network[name]["initial_output"] = initial + return network, name + + +def make_subnet(fromList, net): + subnet = {"class": "subnetwork", "from": fromList, "subnetwork": net} + return subnet + + +# masked computation +def add_mask_layer( + network: object, name: object, fromList: object, mask: object, unit: object = {"class": "copy"}, **kwargs: object +) -> object: + network[name] =
{ + "class": "masked_computation", + "from": fromList, + "mask": mask, + "unit": unit, + } + # more likely to be used in training where input is already masked elsewhere: directly use + if kwargs.get("masked_from", None) is not None: + network[name]["masked_from"] = kwargs.get("masked_from", None) + # heuristics likely not needed anymore, use pad layer to achieve the same + if kwargs.get("initial", None) is not None: + network[name]["unit"]["initial_output"] = kwargs.get("initial", None) + if kwargs.get("keep_last_for_prev", False): + network[name]["keep_last_for_prev"] = True + if kwargs.get("is_output", False): + network[name]["is_output_layer"] = True + return network, name + + +def add_unmask_layer(network, name, fromList, mask, **kwargs): + network[name] = {"class": "unmask", "from": fromList, "mask": mask} + # do not use initial_output but directly the 1st frame of input for the first Fs + if kwargs.get("skip_initial", True): + network[name]["skip_initial"] = True + return network, name + + +def add_padding_layer( + network, name, fromList, axes="T", padding=(0, 1), value=0, mode="constant", n_out=None, **kwargs +): + network[name] = {"class": "pad", "from": fromList, "axes": axes, "padding": padding, "value": value, "mode": mode} + if n_out is not None: + network[name]["n_out"] = n_out + if kwargs.get("is_output", False): + network[name]["is_output_layer"] = True + if kwargs.get("initial", None) is not None: + network[name]["initial_output"] = kwargs.get("initial", None) + if kwargs.get("out_type", None) is not None: + network[name]["out_type"] = kwargs.get("out_type", None) + return network, name + + +def add_time_postfix_layer(network, name, fromList, postfix, repeat=1): + network[name] = {"class": "postfix_in_time", "from": fromList, "postfix": postfix, "repeat": repeat} + return network, name + + +def add_axis_range_layer(network, name, fromList, axis="T", unbroadcast=True): + network[name] = {"class": "range_in_axis", "from": fromList, "axis": axis, "unbroadcast": unbroadcast} + return network, name + + +def add_shift_layer(network, name, fromList, axis="T", amount=1, pad=True, **kwargs): + network[name] = {"class": "shift_axis", "from": fromList, "axis": axis, "amount": amount, "pad": pad} + if kwargs.get("adjust_size", None) is not None: + network[name]["adjust_size_info"] = kwargs.get("adjust_size", None) + if kwargs.get("initial", None) is not None: + network[name]["initial_output"] = kwargs.get("initial", None) + return network, name + + +def add_seq_len_mask_layer(network, name, fromList, axis="T", mask_value=0): + network[name] = {"class": "seq_len_mask", "from": fromList, "axis": axis, "mask_value": mask_value} + return network, name + + +def add_pool_layer(network, name, fromList, mode="max", pool_size=(2,), padding="same", **kwargs): + network[name] = { + "class": "pool", + "mode": mode, + "padding": padding, + "pool_size": pool_size, + "from": fromList, + "trainable": False, + } + return network, name + + +def add_reinterpret_data_layer(network, name, fromList, size_base=None, **kwargs): + network[name] = {"class": "reinterpret_data", "from": fromList} + if kwargs.get("loss", None) is not None: + network = add_loss_to_layer(network, name, **kwargs) + if size_base is not None: + network[name]["size_base"] = size_base + if kwargs.get("enforce_time_major", False): + network[name]["enforce_time_major"] = True + if kwargs.get("set_sparse", None) is not None: + network[name]["set_sparse"] = kwargs.get("set_sparse", None) + if kwargs.get("set_sparse_dim", None) is not None: 
+ network[name]["set_sparse_dim"] = kwargs.get("set_sparse_dim", None) + if kwargs.get("is_output", False): + network[name]["is_output_layer"] = True + return network, name + + +def add_window_layer(network, name, fromList, winSize, winLeft, **kwargs): + network[name] = { + "class": "window", + "from": fromList, + "window_size": winSize, + "window_left": winLeft + # default along time axis and 0 padding (also works inside rec loop) + } + return network, name + + +def add_merge_dim_layer(network, name, fromList, axes="except_time", **kwargs): + network[name] = {"class": "merge_dims", "from": fromList, "axes": axes} + return network, name + + +def add_split_dim_layer(network, name, fromList, axis, dims, **kwargs): + network[name] = {"class": "split_dims", "from": fromList, "axis": axis, "dims": dims} + return network, name + + +def add_slice_layer(network, name, fromList, axis="F", start=None, end=None, step=None): + network[name] = { + "class": "slice", + "from": fromList, + "axis": axis, + "slice_start": start, + "slice_end": end, + "slice_step": step, + } + return network, name + + +def add_squeeze_layer(network, name, fromList, axis, enforce_batch_dim_axis=None): + network[name] = {"class": "squeeze", "from": fromList, "axis": axis} + if enforce_batch_dim_axis is not None: + network[name]["enforce_batch_dim_axis"] = enforce_batch_dim_axis + return network, name + + +def add_layer_norm_layer(network, name, fromList): + network[name] = {"class": "layer_norm", "from": fromList} + return network, name + + +def add_batch_norm_layer(network, name, fromList, **kwargs): + network[name] = {"class": "batch_norm", "from": fromList} + # RETURNN defaults wrong + if kwargs.get("fix_settings", False): + network[name].update( + { + "momentum": 0.1, + "epsilon": 1e-5, + # otherwise eval may be batch-size and utterance-order dependent ! 
+ "update_sample_only_in_training": True, + "delay_sample_update": True, + } + ) + # freeze batch norm running average in training: consistent with testing + if kwargs.get("freeze_average", False): + network[name]["momentum"] = 0.0 + network[name]["use_sample"] = 1.0 + return network, name + + +# eval layer is a also special case of combine layer, but we distinguish them explicitly here +# and only restricted to the 'kind' usage +def add_combine_layer(network, name, fromList, kind="add", **kwargs): + network[name] = {"class": "combine", "from": fromList, "kind": kind} + if kwargs.get("activation", None) is not None: + network[name]["activation"] = kwargs.get("activation", None) + if kwargs.get("with_bias", None) is not None: + network[name]["with_bias"] = kwargs.get("with_bias", None) + if kwargs.get("n_out", None) is not None: + network[name]["n_out"] = kwargs.get("n_out", None) + if kwargs.get("is_output", False): + network[name]["is_output_layer"] = True + return network, name + + +# Note: RETURNN source(i, auto_convert=True, enforce_batch_major=False, as_data=False) +def add_eval_layer(network, name, fromList, eval_str, **kwargs): + network[name] = {"class": "eval", "from": fromList, "eval": eval_str} + if kwargs.get("loss", None) is not None: + network = add_loss_to_layer(network, name, **kwargs) + if kwargs.get("initial", None) is not None: + network[name]["initial_output"] = kwargs.get("initial", None) + if kwargs.get("n_out", None) is not None: + network[name]["n_out"] = kwargs.get("n_out", None) + if kwargs.get("out_type", None) is not None: + network[name]["out_type"] = kwargs.get("out_type", None) + return network, name + + +def add_variable_layer(network, name, shape, **kwargs): + network[name] = {"class": "variable", "shape": shape} + return network, name + + +# generic attention +def add_attention_layer(network, name, base, weights, **kwargs): + network[name] = {"class": "generic_attention", "base": base, "weights": weights} + return network, name + + +def add_spatial_softmax_layer(network, name, fromList, **kwargs): + network[name] = {"class": "softmax_over_spatial", "from": fromList} + return network, name + + +def add_rel_pos_encoding_layer(network, name, fromList, n_out, clipping=64, **kwargs): + network[name] = {"class": "relative_positional_encoding", "from": fromList, "n_out": n_out, "clipping": clipping} + return network, name + + +def add_self_attention_layer( + network, name, fromList, n_out, num_heads, total_key_dim, key_shift=None, attention_dropout=None, **kwargs +): + network[name] = { + "class": "self_attention", + "from": fromList, + "n_out": n_out, + "num_heads": num_heads, + "total_key_dim": total_key_dim, + } + if key_shift is not None: + network[name]["key_shift"] = key_shift + if attention_dropout is not None: + network[name]["attention_dropout"] = attention_dropout + return network, name + + +def add_conv_layer( + network, name, fromList, n_out, filter_size, padding="VALID", l2=0.01, bias=True, activation=None, **kwargs +): + network[name] = { + "class": "conv", + "from": fromList, + "n_out": n_out, + "filter_size": filter_size, + "padding": padding, + "with_bias": bias, + "activation": activation, + } + if l2 is not None: + network[name]["L2"] = l2 + if kwargs.get("strides", None) is not None: + network[name]["strides"] = kwargs.get("strides", None) + if kwargs.get("groups", None) is not None: + network[name]["groups"] = kwargs.get("groups", None) + if not kwargs.get("trainable", True): + network[name]["trainable"] = False + return network, name + + 
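+ +# A minimal sketch of how the helpers above compose (illustrative only, not used by this baseline): +# a two-layer bidirectional LSTM with a CE-trained softmax output; the sizes and layer names here +# are arbitrary example values. +def _example_blstm_network(): + network, frm = make_network() + network, frm = add_lstm_layer(network, "lstm_1", frm, 512) + network, frm = add_lstm_layer(network, "lstm_2", frm, 512) + return add_output_layer(network, frm, loss="ce") + +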
+def add_gating_layer(network, name, fromList, activation=None, gate_activation="sigmoid", **kwargs): + network[name] = {"class": "gating", "from": fromList, "activation": activation, "gate_activation": gate_activation} + return network, name + + +def add_reduce_layer(network, name, fromList, mode="mean", axes="T", keep_dims=False, **kwargs): + network[name] = {"class": "reduce", "from": fromList, "mode": mode, "axes": axes, "keep_dims": keep_dims} + return network, name + + +def add_reduce_out_layer(network, name, fromList, mode="max", num_pieces=2, **kwargs): + network[name] = {"class": "reduce_out", "from": fromList, "mode": mode, "num_pieces": num_pieces} + return network, name + + +# Convolution block +def add_conv_block( + network, fromList, conv_layers, conv_filter, conv_size, pool_size=None, name_prefix="conv", **kwargs +): + network, fromList = add_split_dim_layer(network, "conv_source", fromList, axis="F", dims=(-1, 1)) + for idx in range(conv_layers): + name = name_prefix + "_" + str(idx + 1) + network, fromList = add_conv_layer(network, name, fromList, conv_size, conv_filter, padding="same", **kwargs) + if pool_size is not None: + name += "_pool" + if isinstance(pool_size, list): + assert idx < len(pool_size) + pool = pool_size[idx] + else: + pool = pool_size + assert isinstance(pool, tuple) + if any([p > 1 for p in pool]): + network, fromList = add_pool_layer(network, name, fromList, pool_size=pool) + network, fromList = add_merge_dim_layer(network, "conv_merged", fromList, axes="static") + return network, fromList + + +# BLSTM encoder with optional max-pool subsampling +def build_encoder_network(num_layers=6, size=512, max_pool=[], **kwargs): + network, fromList = make_network() + # Convolution layers (no subsampling) + if kwargs.pop("initial_convolution", False): + # TODO no pooling on feature dim ? 
(correlation is already low) + conv_layers, conv_filter, conv_size, pool = kwargs.pop("convolution_layers", (2, (3, 3), 32, (1, 2))) + network, fromList = add_conv_block( + network, fromList, conv_layers, conv_filter, conv_size, pool_size=pool, **kwargs + ) + # BLSTM layers + for idx in range(num_layers): + name = "lstm_" + str(idx + 1) + network, fromList = add_lstm_layer(network, name, fromList, size, **kwargs) + if max_pool and idx < len(max_pool) and max_pool[idx] > 1: + name = "max_pool_" + str(idx + 1) + network, fromList = add_pool_layer(network, name, fromList, pool_size=(max_pool[idx],)) + return network, fromList + + +# Conformer encoder TODO freeze encoder: pass trainable False +def add_conformer_block(network, name, fromList, size, dropout, l2, **kwargs): + # feed-forward module + def add_ff_module(net, n, fin): + net, fout = add_layer_norm_layer(net, n + "_ln", fin) + net, fout = add_linear_layer(net, n + "_linear_swish", fout, size * 4, l2=l2, activation="swish") + net, fout = add_linear_layer(net, n + "_dropout_linear", fout, size, l2=l2, dropout=dropout) + net, fout = add_copy_layer(net, n + "_dropout", fout, dropout=dropout) + net, fout = add_eval_layer(net, n + "_half_res_add", [fout, fin], "0.5 * source(0) + source(1)") + return net, fout + + # multi-head self-attention module + def add_mhsa_module(net, n, fin, heads, posEncSize, posEncClip, posEnc=True): + net, fout = add_layer_norm_layer(net, n + "_ln", fin) + if posEnc: + net, fpos = add_rel_pos_encoding_layer(net, n + "_relpos_encoding", fout, posEncSize, clipping=posEncClip) + else: + fpos = None + net, fout = add_self_attention_layer( + net, n + "_self_attention", fout, size, heads, size, key_shift=fpos, attention_dropout=dropout + ) + net, fout = add_linear_layer(net, n + "_att_linear", fout, size, l2=l2, bias=False) + net, fout = add_copy_layer(net, n + "_dropout", fout, dropout=dropout) + net, fout = add_combine_layer(net, n + "_res_add", [fout, fin]) + return net, fout + + # convolution module + def add_conv_module(net, n, fin, filterSize, bnFix, bnFreeze, bn2ln): + net, fout = add_layer_norm_layer(net, n + "_ln", fin) + # glu weights merged into pointwise conv, i.e. 
linear layer + net, fout = add_linear_layer(net, n + "_pointwise_conv_1", fout, size * 2, l2=l2) + net, fout = add_gating_layer(net, n + "_glu", fout) + net, fout = add_conv_layer( + net, n + "_depthwise_conv", fout, size, filterSize, padding="same", l2=l2, groups=size + ) + if bn2ln: + net, fout = add_layer_norm_layer(net, n + "_bn2ln", fout) + else: + net, fout = add_batch_norm_layer(net, n + "_bn", fout, fix_settings=bnFix, freeze_average=bnFreeze) + net, fout = add_activation_layer(net, n + "_swish", fout, "swish") + net, fout = add_linear_layer(net, n + "_pointwise_conv_2", fout, size, l2=l2) + net, fout = add_copy_layer(net, n + "_dropout", fout, dropout=dropout) + net, fout = add_combine_layer(net, n + "_res_add", [fout, fin]) + return net, fout + + network, fList = add_ff_module(network, name + "_ffmod_1", fromList) + + mhsa_args = { + "heads": kwargs.get("num_att_heads", 8), + "posEncSize": kwargs.get("pos_enc_size", 64), + "posEncClip": kwargs.get("pos_enc_clip", 64), # default clipping 16 in RETURNN + "posEnc": kwargs.get("pos_encoding", True), + } + conv_args = { + "filterSize": kwargs.get("conv_filter_size", (32,)), + "bnFix": kwargs.get("batch_norm_fix", False), + "bnFreeze": kwargs.get("batch_norm_freeze", False), + "bn2ln": kwargs.get("batch_norm_to_layer_norm", False), + } + if kwargs.get("switch_conv_mhsa_module", False): + network, fList = add_conv_module(network, name + "_conv_mod", fList, **conv_args) + network, fList = add_mhsa_module(network, name + "_mhsa_mod", fList, **mhsa_args) + else: + network, fList = add_mhsa_module(network, name + "_mhsa_mod", fList, **mhsa_args) + network, fList = add_conv_module(network, name + "_conv_mod", fList, **conv_args) + + network, fList = add_ff_module(network, name + "_ffmod_2", fList) + network, fList = add_layer_norm_layer(network, name + "_output", fList) + return network, fList + + +def build_conformer_encoder(num_blocks=12, size=512, dropout=0.1, l2=0.0001, max_pool=[], **kwargs): + network, fromList = make_network() + # Input block + if kwargs.get("initial_convolution", True): + # vgg conv with subsampling 4 + if kwargs.get("vgg_conv", True): + network, fromList = add_conv_block( + network, fromList, 1, (3, 3), 32, pool_size=(1, 2), activation="swish", **kwargs + ) + stride1, stride2 = kwargs.get("vgg_conv_strides", (2, 2)) + network, fList = add_conv_layer( + network, + "conv_2", + network[fromList]["from"], + 64, + (3, 3), + padding="same", + strides=(stride1, 1), + activation="swish", + **kwargs, + ) + network, fList = add_conv_layer( + network, "conv_3", fList, 64, (3, 3), padding="same", strides=(stride2, 1), activation="swish", **kwargs + ) + network[fromList]["from"] = fList + elif kwargs.get("stride_subsampling", False): + conv_layers, conv_filter, conv_size, strides = kwargs.pop("convolution_layers", (2, (3, 3), 32, [2, 2])) + network, fromList = add_conv_block( + network, fromList, conv_layers, conv_filter, conv_size, strides=strides, **kwargs + ) + else: # max_pool subsampling + conv_layers, conv_filter, conv_size, pool = kwargs.pop("convolution_layers", (2, (3, 3), 32, (1, 2))) + network, fromList = add_conv_block( + network, fromList, conv_layers, conv_filter, conv_size, pool_size=pool, **kwargs + ) + assert not max_pool + elif kwargs.get("initial_blstm", False): # BLSTM with subsampling 4 + layers, uniSize, pool = kwargs.pop("blstm_layers", (2, 512, [2, 2])) + network, fromList = build_encoder_network( + num_layers=layers, size=uniSize, max_pool=pool, dropout=dropout, l2=l2, **kwargs + ) + assert not max_pool 
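+ # project the front-end output (conv or BLSTM, possibly subsampled) to the conformer model dimension + # and apply input dropout before the first conformer block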
+ network, fromList = add_linear_layer(network, "input_linear", fromList, size, l2=l2, bias=False) + network, fromList = add_copy_layer(network, "input_dropout", fromList, dropout=dropout) + + # Conformer blocks + for idx in range(num_blocks): + name = "conformer_" + str(idx + 1) + network, fromList = add_conformer_block(network, name, fromList, size, dropout, l2, **kwargs) + # also allow subsampling between conformer blocks + if max_pool and idx < len(max_pool) and max_pool[idx] > 1: + name += "_max_pool" + network, fromList = add_pool_layer(network, name, fromList, pool_size=(max_pool[idx],)) + return network, fromList + + +# -- output and loss -- +def add_loss_layer(network, name, fromList, loss="ce", **kwargs): + network[name] = {"class": "loss", "from": fromList, "loss_": loss} + if kwargs.get("target", None) is not None: + network[name]["target_"] = kwargs.get("target", None) + if kwargs.get("loss_opts", None) is not None: + network[name]["loss_opts_"] = kwargs.get("loss_opts", None) + return network, name + + +def add_output_layer(network, fromList, name="output", loss="ce", loss_opts=None, cls="softmax", **kwargs): + network[name] = {"class": cls, "from": fromList} + if loss is not None: + network = add_loss_to_layer(network, name, loss, loss_opts=loss_opts, **kwargs) + else: + n_out = kwargs.get("n_out", None) + assert n_out is not None, "either loss or n_out need to be given" + network[name]["n_out"] = n_out + network[name]["is_output_layer"] = True + + if kwargs.get("random_norm_init", False): + network[name]["forward_weights_init"] = "random_normal_initializer(mean=0.0, stddev=0.1)" + if kwargs.get("dropout", None) is not None: + network[name]["dropout"] = kwargs.get("dropout", None) + if kwargs.get("loss_scale", None) is not None: + network[name]["loss_scale"] = kwargs.get("loss_scale", None) + if kwargs.get("activation", None) is not None: + network[name]["class"] = "linear" + network[name]["activation"] = kwargs.get("activation", None) + if kwargs.get("reuse_params", None) is not None: + network[name]["reuse_params"] = kwargs.get("reuse_params", None) + network[name].update(kwargs.get("extra_args", {})) + return network + + +def add_sMBR_output(inNetwork, name="output_ac", output="output", ce_smooth=0.1, **kwargs): + network = copy.deepcopy(inNetwork) + network[output]["loss_scale"] = ce_smooth + network[name] = { + "class": "copy", + "from": output, + "loss": "sprint", + "loss_scale": 1 - ce_smooth, + "loss_opts": { + "sprint_opts": crnn.CustomCRNNSprintTrainingJob.create_sprint_loss_opts( + loss_mode="sMBR", num_sprint_instance=1 + ) + }, + } + return network + + +# full-sum training using sprint FSA (so far only fast_bw loss) +def add_full_sum_output_layer(network, fromList, num_classes, loss="fast_bw", name="output", **kwargs): + output_args = { + "name": name, + "loss": loss, + "loss_opts": { + "sprint_opts": crnn.CustomCRNNSprintTrainingJob.create_sprint_loss_opts(**kwargs), + "tdp_scale": kwargs.get("tdp_scale", 0.0), + }, + "extra_args": {"target": None, "n_out": num_classes}, # no target to infer output size + } + return add_output_layer(network, fromList, **output_args) + + +# decoder output layer using rec-layer unit (including prediction and joint network) +def add_decoder_output_rec_layer(network, fromList, recUnit, optimize_move_layers_out=None, **kwargs): + network = copy.deepcopy(network) + network["output"] = { + "class": "rec", + "from": fromList, + # only relevant for beam_search: e.g. 
determine length by targets + "cheating": False, + "target": kwargs.get("target", "classes"), + "unit": recUnit, + } + if optimize_move_layers_out is not None: + network["output"]["optimize_move_layers_out"] = optimize_move_layers_out + if kwargs.get("max_seq_len", None) is not None: + network["output"]["max_seq_len"] = kwargs.get("max_seq_len", None) + return network + + +def add_choice_layer(network, name="output_choice", fromList=["output"], initial=0, beam=1, **kwargs): + network[name] = { + "class": "choice", + "target": kwargs.get("target", "classes"), + "from": fromList, + "initial_output": initial, + # only relevant for beam_search: e.g. task='search' + "cheating": False, # True would include the targets in the beam + "beam_size": beam, + } + if kwargs.get("scheduled_sampling", False): + network[name]["scheduled_sampling"] = kwargs.get("scheduled_sampling", False) + if kwargs.get("input_type", None) is not None: + network[name]["input_type"] = kwargs.get("input_type", None) + # Note: either/none of the following is needed for recognition + # old compile_tf_graph + if kwargs.get("is_stochastic_var", None) is not None: + network[name]["is_stochastic_var"] = kwargs.get("is_stochastic_var", None) + # new compile_tf_graph + if kwargs.get("score_dependent", None) is not None: + network[name]["score_dependent"] = kwargs.get("score_dependent", None) + return network + + +def make_recog_rec_network(trainRecNetwork, removeList=[], update={}, recRemoveList=[], recUpdate={}): + # Note: can not add new layers + def modify(net, toRemove, toUpdate): + network = copy.deepcopy(net) + for lname in net.keys(): + # name pattern match: removal + removed = False + for rk in toRemove: + if rk in lname: + del network[lname] + removed = True + break + if removed: + continue + # name match: dict update + if lname in toUpdate: + network[lname].update(toUpdate[lname]) + return network + + # apply change + recogRecNetwork = modify(trainRecNetwork, removeList, update) + if recRemoveList or recUpdate: + assert recogRecNetwork["output"]["class"] == "rec" + recUnit = modify(recogRecNetwork["output"]["unit"], recRemoveList, recUpdate) + recogRecNetwork["output"]["unit"] = recUnit + return recogRecNetwork + + +# simple zero-encoder estimated internal LM TODO add more +def make_internal_LM_rec_network(recUnit, name, scale, lm_output, num_classes, blankIndex=0, posterior="output"): + assert blankIndex == 0, "assume blank index 0" + assert posterior in recUnit + recUnit[posterior].update({"class": "linear", "activation": "log_softmax"}) + # TODO exclude bias ?
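+ # reuse_params ties the internal-LM logits to the output layer's softmax weights, so this zero-encoder LM estimate adds no extra parameters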
+ recUnit, fList = add_linear_layer(recUnit, "intLM_logits", lm_output, num_classes, reuse_params=posterior) + recUnit, fList = add_slice_layer(recUnit, "intLM_logits_noBlank", fList, start=1) + recUnit, fList = add_activation_layer(recUnit, "intLM_softmax", fList, "log_softmax") + recUnit, fList = add_padding_layer( + recUnit, "intLM_prior", fList, axes="F", value=0, padding=(1, 0), n_out=num_classes + ) + # log(posterior) - alpha * log(prior) + recUnit, fList = add_eval_layer(recUnit, name, [posterior, fList], "source(0) - %s * source(1)" % (str(scale))) + return recUnit + + +## ----------------------- extra python code ----------------------- ## +# SpecAugment # +def get_spec_augment_mask_python( + codeFile=None, + max_time_num=6, + max_time=5, + max_feature_num=4, + max_feature=5, + conservatvie_step=2000, + feature_limit=None, + customRep={}, +): + path = os.path.dirname(os.path.abspath(__file__)) + if codeFile is None: + if feature_limit is not None: + codeFile = os.path.join(path, "spec_augment_mask_flimit.py") + else: + codeFile = os.path.join(path, "spec_augment_mask.py") + elif codeFile in os.listdir(path): + codeFile = os.path.join(path, codeFile) + with open(codeFile, "r") as f: + python_code = f.read() + + python_code = python_code.replace("max_time_num = 6", "max_time_num = %d" % max_time_num) + python_code = python_code.replace("max_time = 5", "max_time = %d" % max_time) + python_code = python_code.replace("max_feature_num = 4", "max_feature_num = %d" % max_feature_num) + python_code = python_code.replace("max_feature = 5", "max_feature = %d" % max_feature) + python_code = python_code.replace("conservatvie_step = 2000", "conservatvie_step = %d" % conservatvie_step) + + if feature_limit is not None: + assert isinstance(feature_limit, int) + python_code = python_code.replace("feature_limit = 80", "feature_limit = %d" % feature_limit) + + for old, new in customRep.items(): + python_code = python_code.replace("%s" % old, "%s" % new) + return python_code + + +def get_extern_data_python(codeFile=None, nInput=50, nOutput=40): + if codeFile is None: + path = os.path.dirname(os.path.abspath(__file__)) + codeFile = os.path.join(path, "extern_data.py") + with open(codeFile, "r") as f: + python_code = f.read() + + python_code = python_code.replace("nInput", str(nInput)) + python_code = python_code.replace("nOutput", str(nOutput)) + return python_code + + +# custom pretrain construction with down-sampling # +def get_pretrain_python(codeFile=None, repetitions="1", customRep={}): + path = os.path.dirname(os.path.abspath(__file__)) + if codeFile is None: + codeFile = os.path.join(path, "pretrain.py") + elif codeFile in os.listdir(path): + codeFile = os.path.join(path, codeFile) + with open(codeFile, "r") as f: + python_code = f.read() + + if not isinstance(repetitions, str): + repetitions = str(repetitions) + if not repetitions == "1": + python_code = python_code.replace("'repetitions': 1", "'repetitions': %s" % repetitions) + for old, new in customRep.items(): + python_code = python_code.replace("%s" % old, "%s" % new) + return python_code + + +def get_segmental_loss_python(codeFile=None, time_axis=None): + path = os.path.dirname(os.path.abspath(__file__)) + if codeFile is None: + codeFile = os.path.join(path, "segmental_loss.py") + elif codeFile in os.listdir(path): + codeFile = os.path.join(path, codeFile) + with open(codeFile, "r") as f: + python_code = f.read() + if time_axis is not None: + python_code = python_code.replace("axis=0", "axis=%d" % time_axis) + return python_code + + +def 
diff --git a/common/baselines/tedlium2/hybrid/nn_config/spec_augment_mask.py b/common/baselines/tedlium2/hybrid/nn_config/spec_augment_mask.py
new file mode 100644
index 000000000..c98f2aeaa
--- /dev/null
+++ b/common/baselines/tedlium2/hybrid/nn_config/spec_augment_mask.py
@@ -0,0 +1,131 @@
+# for debug only
+def summary(name, x):
+    """
+    :param str name:
+    :param tf.Tensor x: (batch,time,feature)
+    """
+    from returnn.tf.compat import v1 as tf
+
+    # tf.summary.image wants [batch_size, height, width, channels],
+    # we have (batch, time, feature).
+    img = tf.expand_dims(x, axis=3)  # (batch,time,feature,1)
+    img = tf.transpose(img, [0, 2, 1, 3])  # (batch,feature,time,1)
+    tf.summary.image(name, img, max_outputs=10)
+    tf.summary.scalar("%s_max_abs" % name, tf.reduce_max(tf.abs(x)))
+    mean = tf.reduce_mean(x)
+    tf.summary.scalar("%s_mean" % name, mean)
+    stddev = tf.sqrt(tf.reduce_mean(tf.square(x - mean)))
+    tf.summary.scalar("%s_stddev" % name, stddev)
+    tf.summary.histogram("%s_hist" % name, tf.reduce_max(tf.abs(x), axis=2))
+
+
+def _mask(x, batch_axis, axis, pos, max_amount):
+    """
+    :param tf.Tensor x: (batch,time,feature)
+    :param int batch_axis:
+    :param int axis:
+    :param tf.Tensor pos: (batch,)
+    :param int|tf.Tensor max_amount: inclusive
+    """
+    from returnn.tf.compat import v1 as tf
+
+    ndim = x.get_shape().ndims
+    n_batch = tf.shape(x)[batch_axis]
+    dim = tf.shape(x)[axis]
+    amount = tf.random_uniform(shape=(n_batch,), minval=1, maxval=max_amount + 1, dtype=tf.int32)
+    pos2 = tf.minimum(pos + amount, dim)
+    idxs = tf.expand_dims(tf.range(0, dim), 0)  # (1,dim)
+    pos_bc = tf.expand_dims(pos, 1)  # (batch,1)
+    pos2_bc = tf.expand_dims(pos2, 1)  # (batch,1)
+    cond = tf.logical_and(tf.greater_equal(idxs, pos_bc), tf.less(idxs, pos2_bc))  # (batch,dim)
+    if batch_axis > axis:
+        cond = tf.transpose(cond)  # (dim,batch)
+    cond = tf.reshape(cond, [tf.shape(x)[i] if i in (batch_axis, axis) else 1 for i in range(ndim)])
+    from returnn.tf.util.basic import where_bc
+
+    x = where_bc(cond, 0.0, x)
+    return x
+
+
+def random_mask(x, batch_axis, axis, min_num, max_num, max_dims):
+    """
+    :param tf.Tensor x: (batch,time,feature)
+    :param int batch_axis:
+    :param int axis:
+    :param int|tf.Tensor min_num:
+    :param int|tf.Tensor max_num: inclusive
+    :param int|tf.Tensor max_dims: inclusive
+    """
+    from returnn.tf.compat import v1 as tf
+
+    n_batch = tf.shape(x)[batch_axis]
+    if isinstance(min_num, int) and isinstance(max_num, int) and min_num == max_num:
+        num = min_num
+    else:
+        num = tf.random_uniform(shape=(n_batch,), minval=min_num, maxval=max_num + 1, dtype=tf.int32)
+    # https://github.com/tensorflow/tensorflow/issues/9260
+    # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
+    z = -tf.log(-tf.log(tf.random_uniform((n_batch, tf.shape(x)[axis]), 0, 1)))
+    _, indices = tf.nn.top_k(z, num if isinstance(num, int) else tf.reduce_max(num))
+    # indices should be sorted, and of shape (batch,num), entries (int32) in [0,dim)
+    # indices = tf.Print(indices, ["indices", indices, tf.shape(indices)])
+    if isinstance(num, int):
+        for i in range(num):
+            x = _mask(x, batch_axis=batch_axis, axis=axis, pos=indices[:, i], max_amount=max_dims)
+    else:
+        _, x = tf.while_loop(
+            cond=lambda i, _: tf.less(i, tf.reduce_max(num)),
+            body=lambda i, x: (
+                i + 1,
+                tf.where(
+                    tf.less(i, num),
+                    _mask(x, batch_axis=batch_axis, axis=axis, pos=indices[:, i], max_amount=max_dims),
+                    x,
+                ),
+            ),
+            loop_vars=(0, x),
+        )
+    return x
+
+
+def transform(data, network):
+    # to be adjusted (20-50%)
+    max_time_num = 6
+    max_time = 5
+
+    max_feature_num = 4
+    max_feature = 5
+
+    # mask counts are halved before this step
+    conservative_step = 2000
+
+    x = data.placeholder
+    from returnn.tf.compat import v1 as tf
+
+    # summary("features", x)
+    step = network.global_train_step
+    increase_flag = tf.where(tf.greater_equal(step, conservative_step), 0, 1)
+
+    def get_masked():
+        x_masked = x
+        x_masked = random_mask(
+            x_masked,
+            batch_axis=data.batch_dim_axis,
+            axis=data.time_dim_axis,
+            min_num=0,
+            max_num=tf.maximum(tf.shape(x)[data.time_dim_axis] // (2 * max_time), max_time_num) // (1 + increase_flag),
+            max_dims=max_time,
+        )
+        x_masked = random_mask(
+            x_masked,
+            batch_axis=data.batch_dim_axis,
+            axis=data.feature_dim_axis,
+            min_num=0,
+            max_num=max_feature_num // (1 + increase_flag),
+            max_dims=max_feature,
+        )
+        # summary("features_mask", x_masked)
+        return x_masked
+
+    x = network.cond_on_train(get_masked, lambda: x)
+    return x
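This transform is never imported by the recipe; the file content is injected into the generated RETURNN config (via get_spec_augment_mask_python above) and hooked into the network through an eval layer. A minimal sketch of the usual wiring, with illustrative layer names:

    network = {
        "specaug": {
            "class": "eval",
            "from": "data",
            "eval": "self.network.get_config().typed_value('transform')"
                    "(source(0, as_data=True), network=self.network)",
        },
        # encoder layers then read "from": "specaug"
    }

cond_on_train ensures the masking graph is only active during training; at inference the features pass through unchanged.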
diff --git a/common/datasets/tedlium2/corpus.py b/common/datasets/tedlium2/corpus.py
index e732db251..420adcdfc 100644
--- a/common/datasets/tedlium2/corpus.py
+++ b/common/datasets/tedlium2/corpus.py
@@ -28,10 +28,7 @@ def get_bliss_corpus_dict(audio_format: str = "wav", output_prefix: str = "datas
     bliss_corpus_dict = download_data_dict(output_prefix=output_prefix).bliss_nist
 
     audio_format_options = {
-        "wav": {
-            "output_format": "wav",
-            "codec": "pcm_s16le",
-        },
+        "wav": {"output_format": "wav", "codec": "pcm_s16le"},
         "ogg": {"output_format": "ogg", "codec": "libvorbis"},
         "flac": {"output_format": "flac", "codec": "flac"},
     }
diff --git a/common/datasets/tedlium2/export.py b/common/datasets/tedlium2/export.py
index 1919fa8c0..b8a324773 100644
--- a/common/datasets/tedlium2/export.py
+++ b/common/datasets/tedlium2/export.py
@@ -71,7 +71,7 @@ def _export_lexicon(output_prefix: str = "datasets"):
     """
     lexicon_output_prefix = os.path.join(output_prefix, TEDLIUM_PREFIX, "lexicon")
 
-    bliss_lexicon = get_bliss_lexicon(output_prefix=output_prefix)
+    bliss_lexicon = get_bliss_lexicon(add_unknown_phoneme_and_mapping=False, output_prefix=output_prefix)
     tk.register_output(os.path.join(lexicon_output_prefix, "tedlium2.lexicon.xml.gz"), bliss_lexicon)
 
     g2p_bliss_lexicon = get_g2p_augmented_bliss_lexicon(
diff --git a/common/setups/rasr/gmm_system.py b/common/setups/rasr/gmm_system.py
index 8609a0caf..a16db9df8 100644
--- a/common/setups/rasr/gmm_system.py
+++ b/common/setups/rasr/gmm_system.py
@@ -1064,7 +1064,7 @@ def get_gmm_output(
            gmm_output.acoustic_mixtures = self.mixtures[corpus_key][f"train_{steps.get_prev_gmm_step(step_idx)}"][-1]
 
        state_tying_job = allophones.DumpStateTyingJob(self.crp[corpus_key])
-       self.jobs[corpus_key]["state_tying"] = state_tying_job
+       self.jobs[corpus_key]["state_tying_gmm_out"] = state_tying_job
        tk.register_output(
            "final_{}_state_tying".format(corpus_key),
            state_tying_job.out_state_tying,
diff --git a/common/setups/rasr/hybrid_system.py b/common/setups/rasr/hybrid_system.py
index 34acb12c8..1a4f5a54d 100644
--- a/common/setups/rasr/hybrid_system.py
+++ b/common/setups/rasr/hybrid_system.py
@@ -1,4 +1,4 @@
-__all__ = ["HybridSystem"]
+__all__ = ["HybridArgs", "HybridSystem"]
 
 import copy
 import itertools
@@ -24,7 +24,7 @@
 from i6_core.mm import CreateDummyMixturesJob
 from i6_core.returnn import ReturnnComputePriorJobV2
 
-from .nn_system import NnSystem
+from .nn_system import NnSystem, returnn_training
 
 from .hybrid_decoder import HybridDecoder
 from .util import (
@@ -94,17 +94,13 @@ def __init__(
         self.cv_corpora = []
         self.devtrain_corpora = []
 
-        self.train_input_data = (
-            None
-        )  # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
-        self.cv_input_data = (
-            None
-        )  # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
-        self.devtrain_input_data = (
-            None
-        )  # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
-        self.dev_input_data = None  # type:Optional[Dict[str, ReturnnRasrDataInput]]
-        self.test_input_data = None  # type:Optional[Dict[str, ReturnnRasrDataInput]]
+        self.train_input_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None
+        self.cv_input_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None
+        self.devtrain_input_data: Optional[
+            Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]
+        ] = None
+        self.dev_input_data: Optional[Dict[str, ReturnnRasrDataInput]] = None
+        self.test_input_data: Optional[Dict[str, ReturnnRasrDataInput]] = None
 
         self.train_cv_pairing = None
 
@@ -228,16 +224,22 @@ def returnn_training(
         cv_corpus_key,
         devtrain_corpus_key=None,
     ) -> returnn.ReturnnTrainingJob:
-        if nn_train_args.returnn_root is None:
-            nn_train_args.returnn_root = self.returnn_root
-        if nn_train_args.returnn_python_exe is None:
-            nn_train_args.returnn_python_exe = self.returnn_python_exe
+        if isinstance(nn_train_args, ReturnnTrainingJobArgs):
+            if nn_train_args.returnn_root is None:
+                nn_train_args.returnn_root = self.returnn_root
+            if nn_train_args.returnn_python_exe is None:
+                nn_train_args.returnn_python_exe = self.returnn_python_exe
 
-        train_job = returnn.ReturnnTrainingJob(
+        train_job = returnn_training(
+            name=name,
             returnn_config=returnn_config,
-            returnn_root=self.returnn_root,
-            returnn_python_exe=self.returnn_python_exe,
-            **nn_train_args,
+            training_args=nn_train_args,
+            train_data=self.train_input_data[train_corpus_key],
+            cv_data=self.cv_input_data[cv_corpus_key],
+            additional_data={"devtrain": self.devtrain_input_data[devtrain_corpus_key]}
+            if devtrain_corpus_key is not None
+            else None,
+            register_output=False,
         )
         self._add_output_alias_for_train_job(
             train_job=train_job,
@@ -468,7 +470,6 @@ def nn_recog(
                 checkpoints=checkpoints,
                 train_job=train_job,
                 recognition_corpus_key=dev_c,
-                acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
                 **recog_args,
             )
 
@@ -484,7 +485,6 @@ def nn_recog(
                 checkpoints=checkpoints,
                 train_job=train_job,
                 recognition_corpus_key=tst_c,
-                acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
                 **r_args,
             )
@@ -504,7 +504,7 @@ def nn_compile_graph(
         :return: the TF graph
         """
         graph_compile_job = returnn.CompileTFGraphJob(
-            returnn_config,
+            returnn_config=returnn_config,
             epoch=epoch,
             returnn_root=self.returnn_root,
             returnn_python_exe=self.returnn_python_exe,
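After this change, HybridSystem.returnn_training accepts either a plain kwargs dict or a ReturnnTrainingJobArgs dataclass; only the dataclass path inherits the system-wide RETURNN root and launcher. A call sketch, assuming the dataclass fields mirror ReturnnTrainingJob's keyword arguments (the field values and corpus keys are illustrative):

    from i6_experiments.common.setups.rasr.util import ReturnnTrainingJobArgs

    nn_train_args = ReturnnTrainingJobArgs(
        num_epochs=125,   # assumed field names, matching ReturnnTrainingJob kwargs
        mem_rqmt=8,
        time_rqmt=168,
    )
    train_job = hybrid_system.returnn_training(
        name="blstm_base",
        returnn_config=blstm_config,
        nn_train_args=nn_train_args,
        train_corpus_key="train",
        cv_corpus_key="cv",
    )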
diff --git a/common/setups/rasr/nn_system.py b/common/setups/rasr/nn_system.py
index b3feaf101..93321e823 100644
--- a/common/setups/rasr/nn_system.py
+++ b/common/setups/rasr/nn_system.py
@@ -1,4 +1,4 @@
-__all__ = ["NnSystem"]
+__all__ = ["NnSystem", "returnn_training"]
 
 import copy
 from dataclasses import asdict
@@ -6,15 +6,14 @@
 
 # -------------------- Sisyphus --------------------
 
-import sisyphus.toolkit as tk
-import sisyphus.global_settings as gs
+from sisyphus import tk, gs
 
 # -------------------- Recipes --------------------
 
 import i6_core.returnn as returnn
 
 from .rasr_system import RasrSystem
-
+from .util import ReturnnTrainingJobArgs, AllowedReturnnTrainingDataInput
 
 # -------------------- Init --------------------
 
@@ -79,3 +78,35 @@ def get_native_ops(self, op_names: Optional[List[str]]) -> Optional[List[tk.Path
             if op_name not in self.native_ops.keys():
                 self.compile_native_op(op_name)
         return [self.native_ops[op_name] for op_name in op_names]
+
+
+def returnn_training(
+    name: str,
+    returnn_config: returnn.ReturnnConfig,
+    training_args: Union[Dict, ReturnnTrainingJobArgs],
+    train_data: AllowedReturnnTrainingDataInput,
+    *,
+    cv_data: Optional[AllowedReturnnTrainingDataInput] = None,
+    additional_data: Optional[Dict[str, AllowedReturnnTrainingDataInput]] = None,
+    register_output: bool = True,
+) -> returnn.ReturnnTrainingJob:
+    assert isinstance(returnn_config, returnn.ReturnnConfig)
+
+    config = copy.deepcopy(returnn_config)
+
+    config.config["train"] = train_data if isinstance(train_data, dict) else train_data.get_data_dict()
+    if cv_data is not None:
+        config.config["dev"] = cv_data if isinstance(cv_data, dict) else cv_data.get_data_dict()
+    if additional_data is not None:
+        config.config["eval_datasets"] = {}
+        for data_name, data in additional_data.items():  # do not shadow ``name``, it is used below
+            config.config["eval_datasets"][data_name] = data if isinstance(data, dict) else data.get_data_dict()
+    returnn_training_job = returnn.ReturnnTrainingJob(
+        returnn_config=config,
+        **(asdict(training_args) if isinstance(training_args, ReturnnTrainingJobArgs) else training_args),
+    )
+    if register_output:
+        returnn_training_job.add_alias(f"nn_train/{name}")
+        tk.register_output(f"nn_train/{name}_learning_rates.png", returnn_training_job.out_plot_lr)
+
+    return returnn_training_job
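A usage sketch of the standalone helper (the dataset dicts and paths are placeholders; any data input object offering get_data_dict() works the same way):

    from sisyphus import tk

    train_job = returnn_training(
        name="hybrid_blstm_v1",
        returnn_config=my_returnn_config,  # an i6_core returnn.ReturnnConfig
        training_args={"num_epochs": 125, "log_verbosity": 5, "mem_rqmt": 8},
        train_data={"class": "HDFDataset", "files": [tk.Path("/path/to/train.hdf")]},
        cv_data={"class": "HDFDataset", "files": [tk.Path("/path/to/cv.hdf")]},
    )
    # with register_output=True (the default) this also creates the
    # nn_train/hybrid_blstm_v1 alias and registers the learning-rate plot

Note that the function deep-copies the config before writing the train/dev/eval_datasets entries, so the caller's ReturnnConfig stays untouched and can be reused across trainings.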