1- __all__ = ["HybridArgs" , " HybridSystem" ]
1+ __all__ = ["HybridSystem" ]
22
33import copy
44import itertools
2121 add_tf_flow_to_base_flow ,
2222)
2323from i6_core .util import MultiPath , MultiOutputPath
24+ from i6_core .mm import CreateDummyMixturesJob
25+ from i6_core .returnn import ReturnnComputePriorJobV2
2426
2527from .nn_system import NnSystem
28+ from .hybrid_decoder import HybridDecoder
2629
2730from .util import (
2831 RasrInitArgs ,
2932 ReturnnRasrDataInput ,
30- OggZipHdfDataInput ,
3133 HybridArgs ,
3234 NnRecogArgs ,
3335 RasrSteps ,
3436 NnForcedAlignArgs ,
37+ ReturnnTrainingJobArgs ,
38+ AllowedReturnnTrainingDataInput ,
3539)
3640
3741# -------------------- Init --------------------
@@ -90,9 +94,15 @@ def __init__(
9094 self .cv_corpora = []
9195 self .devtrain_corpora = []
9296
93- self .train_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
94- self .cv_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
95- self .devtrain_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97+ self .train_input_data = (
98+ None
99+ ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
100+ self .cv_input_data = (
101+ None
102+ ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
103+ self .devtrain_input_data = (
104+ None
105+ ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
96106 self .dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97107 self .test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
98108
@@ -128,9 +138,9 @@ def _add_output_alias_for_train_job(
128138 def init_system (
129139 self ,
130140 rasr_init_args : RasrInitArgs ,
131- train_data : Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]],
132- cv_data : Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]],
133- devtrain_data : Optional [Dict [str , Union [ReturnnRasrDataInput , OggZipHdfDataInput ]]] = None ,
141+ train_data : Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]],
142+ cv_data : Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]],
143+ devtrain_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None ,
134144 dev_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None ,
135145 test_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None ,
136146 train_cv_pairing : Optional [List [Tuple [str , ...]]] = None , # List[Tuple[trn_c, cv_c, name, dvtr_c]]
@@ -211,21 +221,17 @@ def generate_lattices(self):
211221
212222 def returnn_training (
213223 self ,
214- name ,
215- returnn_config ,
216- nn_train_args ,
224+ name : str ,
225+ returnn_config : returnn . ReturnnConfig ,
226+ nn_train_args : Union [ Dict , ReturnnTrainingJobArgs ] ,
217227 train_corpus_key ,
218228 cv_corpus_key ,
219229 devtrain_corpus_key = None ,
220- ):
221- assert isinstance (returnn_config , returnn .ReturnnConfig )
222-
223- returnn_config .config ["train" ] = self .train_input_data [train_corpus_key ].get_data_dict ()
224- returnn_config .config ["dev" ] = self .cv_input_data [cv_corpus_key ].get_data_dict ()
225- if devtrain_corpus_key is not None :
226- returnn_config .config ["eval_datasets" ] = {
227- "devtrain" : self .devtrain_input_data [devtrain_corpus_key ].get_data_dict ()
228- }
230+ ) -> returnn .ReturnnTrainingJob :
231+ if nn_train_args .returnn_root is None :
232+ nn_train_args .returnn_root = self .returnn_root
233+ if nn_train_args .returnn_python_exe is None :
234+ nn_train_args .returnn_python_exe = self .returnn_python_exe
229235
230236 train_job = returnn .ReturnnTrainingJob (
231237 returnn_config = returnn_config ,
@@ -346,7 +352,7 @@ def nn_recognition(
346352 name : str ,
347353 returnn_config : returnn .ReturnnConfig ,
348354 checkpoints : Dict [int , returnn .Checkpoint ],
349- acoustic_mixture_path : tk . Path , # TODO maybe Optional if prior file provided -> automatically construct dummy file
355+ train_job : Union [ returnn . ReturnnTrainingJob , returnn . ReturnnRasrTrainingJob ],
350356 prior_scales : List [float ],
351357 pronunciation_scales : List [float ],
352358 lm_scales : List [float ],
@@ -384,15 +390,31 @@ def nn_recognition(
384390
385391 for pron , lm , prior , epoch in itertools .product (pronunciation_scales , lm_scales , prior_scales , epochs ):
386392 assert epoch in checkpoints .keys ()
387- assert acoustic_mixture_path is not None
388-
389- if use_epoch_for_compile :
390- tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
393+ acoustic_mixture_path = CreateDummyMixturesJob (
394+ num_mixtures = returnn_config .config ["extern_data" ]["classes" ]["dim" ],
395+ num_features = returnn_config .config ["extern_data" ]["data" ]["dim" ],
396+ ).out_mixtures
397+ lmgc_scorer = rasr .GMMFeatureScorer (acoustic_mixture_path )
398+ prior_job = ReturnnComputePriorJobV2 (
399+ model_checkpoint = checkpoints [epoch ],
400+ returnn_config = train_job .returnn_config ,
401+ returnn_python_exe = train_job .returnn_python_exe ,
402+ returnn_root = train_job .returnn_root ,
403+ log_verbosity = train_job .returnn_config .post_config ["log_verbosity" ],
404+ )
391405
406+ prior_job .add_alias ("extract_nn_prior/" + name )
407+ prior_file = prior_job .out_prior_xml_file
408+ assert prior_file is not None
392409 scorer = rasr .PrecomputedHybridFeatureScorer (
393410 prior_mixtures = acoustic_mixture_path ,
394411 priori_scale = prior ,
412+ prior_file = prior_file ,
395413 )
414+ assert acoustic_mixture_path is not None
415+
416+ if use_epoch_for_compile :
417+ tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
396418
397419 tf_flow = make_precomputed_hybrid_tf_feature_flow (
398420 tf_checkpoint = checkpoints [epoch ],
@@ -419,6 +441,8 @@ def nn_recognition(
419441 parallelize_conversion = parallelize_conversion ,
420442 rtf = rtf ,
421443 mem = mem ,
444+ lmgc_alias = f"lmgc/{ name } /{ recognition_corpus_key } -{ recog_name } " ,
445+ lmgc_scorer = lmgc_scorer ,
422446 ** kwargs ,
423447 )
424448
@@ -429,15 +453,22 @@ def nn_recog(
429453 returnn_config : Path ,
430454 checkpoints : Dict [int , returnn .Checkpoint ],
431455 step_args : HybridArgs ,
456+ train_job : Union [returnn .ReturnnTrainingJob , returnn .ReturnnRasrTrainingJob ],
432457 ):
433458 for recog_name , recog_args in step_args .recognition_args .items ():
459+ recog_args = copy .deepcopy (recog_args )
460+ whitelist = recog_args .pop ("training_whitelist" , None )
461+ if whitelist :
462+ if train_name not in whitelist :
463+ continue
434464 for dev_c in self .dev_corpora :
435465 self .nn_recognition (
436466 name = f"{ train_corpus_key } -{ train_name } -{ recog_name } " ,
437467 returnn_config = returnn_config ,
438468 checkpoints = checkpoints ,
439- acoustic_mixture_path = self . train_input_data [ train_corpus_key ]. acoustic_mixtures ,
469+ train_job = train_job ,
440470 recognition_corpus_key = dev_c ,
471+ acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
441472 ** recog_args ,
442473 )
443474
@@ -451,8 +482,9 @@ def nn_recog(
451482 name = f"{ train_name } -{ recog_name } " ,
452483 returnn_config = returnn_config ,
453484 checkpoints = checkpoints ,
454- acoustic_mixture_path = self . train_input_data [ train_corpus_key ]. acoustic_mixtures ,
485+ train_job = train_job ,
455486 recognition_corpus_key = tst_c ,
487+ acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
456488 ** r_args ,
457489 )
458490
@@ -509,7 +541,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
509541 train_corpus_key = trn_c ,
510542 cv_corpus_key = cv_c ,
511543 )
512- else :
544+ elif isinstance ( self . train_input_data [ trn_c ], AllowedReturnnTrainingDataInput ) :
513545 returnn_train_job = self .returnn_training (
514546 name = name ,
515547 returnn_config = step_args .returnn_training_configs [name ],
@@ -518,6 +550,8 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
518550 cv_corpus_key = cv_c ,
519551 devtrain_corpus_key = dvtr_c ,
520552 )
553+ else :
554+ raise NotImplementedError
521555
522556 returnn_recog_config = step_args .returnn_recognition_configs .get (
523557 name , step_args .returnn_training_configs [name ]
@@ -529,6 +563,7 @@ def run_nn_step(self, step_name: str, step_args: HybridArgs):
529563 returnn_config = returnn_recog_config ,
530564 checkpoints = returnn_train_job .out_checkpoints ,
531565 step_args = step_args ,
566+ train_job = returnn_train_job ,
532567 )
533568
534569 def run_nn_recog_step (self , step_args : NnRecogArgs ):
0 commit comments