2121 add_tf_flow_to_base_flow ,
2222)
2323from i6_core .util import MultiPath , MultiOutputPath
24+ from i6_core .mm import CreateDummyMixturesJob
25+ from i6_core .returnn import ReturnnComputePriorJobV2
2426
25- from .nn_system import NnSystem
27+ from .hybrid_decoder import HybridDecoder
28+ from .nn_system import NnSystem , returnn_training
2629
2730from .util import (
2831 RasrInitArgs ,
2932 ReturnnRasrDataInput ,
30- OggZipHdfDataInput ,
3133 HybridArgs ,
3234 NnRecogArgs ,
3335 RasrSteps ,
3436 NnForcedAlignArgs ,
37+ ReturnnTrainingJobArgs ,
38+ AllowedReturnnTrainingDataInput ,
3539)
3640
3741# -------------------- Init --------------------
@@ -90,11 +94,13 @@ def __init__(
9094 self .cv_corpora = []
9195 self .devtrain_corpora = []
9296
93- self .train_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
94- self .cv_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
95- self .devtrain_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
96- self .dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97- self .test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97+ self .train_input_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None
98+ self .cv_input_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None
99+ self .devtrain_input_data : Optional [
100+ Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]
101+ ] = None
102+ self .dev_input_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None
103+ self .test_input_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None
98104
99105 self .train_cv_pairing = None
100106
@@ -128,9 +134,9 @@ def _add_output_alias_for_train_job(
128134 def init_system (
129135 self ,
130136 rasr_init_args : RasrInitArgs ,
131- train_data : Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]],
132- cv_data : Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]],
133- devtrain_data : Optional [Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]]] = None ,
137+ train_data : Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]],
138+ cv_data : Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]],
139+ devtrain_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None ,
134140 dev_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None ,
135141 test_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None ,
136142 train_cv_pairing : Optional [List [Tuple [str , ...]]] = None , # List[Tuple[trn_c, cv_c, name, dvtr_c]]
@@ -211,27 +217,29 @@ def generate_lattices(self):
211217
212218 def returnn_training (
213219 self ,
214- name ,
215- returnn_config ,
216- nn_train_args ,
220+ name : str ,
221+ returnn_config : returnn . ReturnnConfig ,
222+ nn_train_args : Union [ Dict , ReturnnTrainingJobArgs ] ,
217223 train_corpus_key ,
218224 cv_corpus_key ,
219225 devtrain_corpus_key = None ,
220- ):
221- assert isinstance (returnn_config , returnn .ReturnnConfig )
222-
223- returnn_config .config ["train" ] = self .train_input_data [train_corpus_key ].get_data_dict ()
224- returnn_config .config ["dev" ] = self .cv_input_data [cv_corpus_key ].get_data_dict ()
225- if devtrain_corpus_key is not None :
226- returnn_config .config ["eval_datasets" ] = {
227- "devtrain" : self .devtrain_input_data [devtrain_corpus_key ].get_data_dict ()
228- }
229-
230- train_job = returnn .ReturnnTrainingJob (
226+ ) -> returnn .ReturnnTrainingJob :
227+ if isinstance (nn_train_args , ReturnnTrainingJobArgs ):
228+ if nn_train_args .returnn_root is None :
229+ nn_train_args .returnn_root = self .returnn_root
230+ if nn_train_args .returnn_python_exe is None :
231+ nn_train_args .returnn_python_exe = self .returnn_python_exe
232+
233+ train_job = returnn_training (
234+ name = name ,
231235 returnn_config = returnn_config ,
232- returnn_root = self .returnn_root ,
233- returnn_python_exe = self .returnn_python_exe ,
234- ** nn_train_args ,
236+ training_args = nn_train_args ,
237+ train_data = self .train_input_data [train_corpus_key ],
238+ cv_data = self .cv_input_data [cv_corpus_key ],
239+ additional_data = {"devtrain" : self .devtrain_input_data [devtrain_corpus_key ]}
240+ if devtrain_corpus_key is not None
241+ else None ,
242+ register_output = False ,
235243 )
236244 self ._add_output_alias_for_train_job (
237245 train_job = train_job ,
@@ -346,7 +354,9 @@ def nn_recognition(
346354 name : str ,
347355 returnn_config : returnn .ReturnnConfig ,
348356 checkpoints : Dict [int , returnn .Checkpoint ],
349- acoustic_mixture_path : tk .Path , # TODO maybe Optional if prior file provided -> automatically construct dummy file
357+ acoustic_mixture_path : Optional [
358+ tk .Path
359+ ], # TODO maybe Optional if prior file provided -> automatically construct dummy file
350360 prior_scales : List [float ],
351361 pronunciation_scales : List [float ],
352362 lm_scales : List [float ],
@@ -362,6 +372,7 @@ def nn_recognition(
362372 use_epoch_for_compile = False ,
363373 forward_output_layer = "output" ,
364374 native_ops : Optional [List [str ]] = None ,
375+ train_job : Optional [Union [returnn .ReturnnTrainingJob , returnn .ReturnnRasrTrainingJob ]] = None ,
365376 ** kwargs ,
366377 ):
367378 with tk .block (f"{ name } _recognition" ):
@@ -383,17 +394,37 @@ def nn_recognition(
383394 epochs = epochs if epochs is not None else list (checkpoints .keys ())
384395
385396 for pron , lm , prior , epoch in itertools .product (pronunciation_scales , lm_scales , prior_scales , epochs ):
386- assert epoch in checkpoints .keys ()
387- assert acoustic_mixture_path is not None
388-
389- if use_epoch_for_compile :
390- tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
391397
398+ assert epoch in checkpoints .keys ()
399+ prior_file = None
400+ lmgc_scorer = None
401+ if acoustic_mixture_path is None :
402+ assert train_job is not None , "Need ReturnnTrainingJob for computation of priors"
403+ tmp_acoustic_mixture_path = CreateDummyMixturesJob (
404+ num_mixtures = returnn_config .config ["extern_data" ]["classes" ]["dim" ],
405+ num_features = returnn_config .config ["extern_data" ]["data" ]["dim" ],
406+ ).out_mixtures
407+ lmgc_scorer = rasr .GMMFeatureScorer (tmp_acoustic_mixture_path )
408+ prior_job = ReturnnComputePriorJobV2 (
409+ model_checkpoint = checkpoints [epoch ],
410+ returnn_config = train_job .returnn_config ,
411+ returnn_python_exe = train_job .returnn_python_exe ,
412+ returnn_root = train_job .returnn_root ,
413+ log_verbosity = train_job .returnn_config .post_config ["log_verbosity" ],
414+ )
415+ prior_job .add_alias ("extract_nn_prior/" + name )
416+ prior_file = prior_job .out_prior_xml_file
417+ else :
418+ tmp_acoustic_mixture_path = acoustic_mixture_path
392419 scorer = rasr .PrecomputedHybridFeatureScorer (
393- prior_mixtures = acoustic_mixture_path ,
420+ prior_mixtures = tmp_acoustic_mixture_path , # This needs to be a new variable otherwise nesting causes undesired behavior
394421 priori_scale = prior ,
422+ prior_file = prior_file ,
395423 )
396424
425+ if use_epoch_for_compile :
426+ tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
427+
397428 tf_flow = make_precomputed_hybrid_tf_feature_flow (
398429 tf_checkpoint = checkpoints [epoch ],
399430 tf_graph = tf_graph ,
@@ -419,6 +450,8 @@ def nn_recognition(
419450 parallelize_conversion = parallelize_conversion ,
420451 rtf = rtf ,
421452 mem = mem ,
453+ lmgc_alias = f"lmgc/{ name } /{ recognition_corpus_key } -{ recog_name } " ,
454+ lmgc_scorer = lmgc_scorer ,
422455 ** kwargs ,
423456 )
424457
@@ -429,14 +462,21 @@ def nn_recog(
429462 returnn_config : Path ,
430463 checkpoints : Dict [int , returnn .Checkpoint ],
431464 step_args : HybridArgs ,
465+ train_job : Union [returnn .ReturnnTrainingJob , returnn .ReturnnRasrTrainingJob ],
432466 ):
433467 for recog_name , recog_args in step_args .recognition_args .items ():
468+ recog_args = copy .deepcopy (recog_args )
469+ whitelist = recog_args .pop ("training_whitelist" , None )
470+ if whitelist :
471+ if train_name not in whitelist :
472+ continue
434473 for dev_c in self .dev_corpora :
435474 self .nn_recognition (
436475 name = f"{ train_corpus_key } -{ train_name } -{ recog_name } " ,
437476 returnn_config = returnn_config ,
438477 checkpoints = checkpoints ,
439478 acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
479+ train_job = train_job ,
440480 recognition_corpus_key = dev_c ,
441481 ** recog_args ,
442482 )
@@ -452,6 +492,7 @@ def nn_recog(
452492 returnn_config = returnn_config ,
453493 checkpoints = checkpoints ,
454494 acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
495+ train_job = train_job ,
455496 recognition_corpus_key = tst_c ,
456497 ** r_args ,
457498 )
@@ -472,7 +513,7 @@ def nn_compile_graph(
472513 :return: the TF graph
473514 """
474515 graph_compile_job = returnn .CompileTFGraphJob (
475- returnn_config ,
516+ returnn_config = returnn_config ,
476517 epoch = epoch ,
477518 returnn_root = self .returnn_root ,
478519 returnn_python_exe = self .returnn_python_exe ,
@@ -509,7 +550,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
509550 train_corpus_key = trn_c ,
510551 cv_corpus_key = cv_c ,
511552 )
512- else :
553+ elif isinstance ( self . train_input_data [ trn_c ], AllowedReturnnTrainingDataInput ) :
513554 returnn_train_job = self .returnn_training (
514555 name = name ,
515556 returnn_config = step_args .returnn_training_configs [name ],
@@ -518,6 +559,8 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
518559 cv_corpus_key = cv_c ,
519560 devtrain_corpus_key = dvtr_c ,
520561 )
562+ else :
563+ raise NotImplementedError
521564
522565 returnn_recog_config = step_args .returnn_recognition_configs .get (
523566 name , step_args .returnn_training_configs [name ]
@@ -529,6 +572,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
529572 returnn_config = returnn_recog_config ,
530573 checkpoints = returnn_train_job .out_checkpoints ,
531574 step_args = step_args ,
575+ train_job = returnn_train_job ,
532576 )
533577
534578 def run_nn_recog_step (self , step_args : NnRecogArgs ):
0 commit comments