1- __all__ = ["HybridSystem" ]
1+ __all__ = ["HybridArgs" , " HybridSystem" ]
22
33import copy
44import itertools
@@ -94,17 +94,13 @@ def __init__(
9494 self .cv_corpora = []
9595 self .devtrain_corpora = []
9696
97- self .train_input_data = (
98- None
99- ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
100- self .cv_input_data = (
101- None
102- ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
103- self .devtrain_input_data = (
104- None
105- ) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
106- self .dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
107- self .test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97+ self .train_input_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None
98+ self .cv_input_data : Optional [Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]] = None
99+ self .devtrain_input_data : Optional [
100+ Dict [str , Union [ReturnnRasrDataInput , AllowedReturnnTrainingDataInput ]]
101+ ] = None
102+ self .dev_input_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None
103+ self .test_input_data : Optional [Dict [str , ReturnnRasrDataInput ]] = None
108104
109105 self .train_cv_pairing = None
110106
@@ -373,7 +369,6 @@ def nn_recognition(
373369 use_epoch_for_compile = False ,
374370 forward_output_layer = "output" ,
375371 native_ops : Optional [List [str ]] = None ,
376- acoustic_mixture_path : Optional [tk .Path ] = None ,
377372 ** kwargs ,
378373 ):
379374 with tk .block (f"{ name } _recognition" ):
@@ -395,6 +390,7 @@ def nn_recognition(
395390 epochs = epochs if epochs is not None else list (checkpoints .keys ())
396391
397392 for pron , lm , prior , epoch in itertools .product (pronunciation_scales , lm_scales , prior_scales , epochs ):
393+
398394 assert epoch in checkpoints .keys ()
399395 acoustic_mixture_path = CreateDummyMixturesJob (
400396 num_mixtures = returnn_config .config ["extern_data" ]["classes" ]["dim" ],
@@ -408,16 +404,15 @@ def nn_recognition(
408404 returnn_root = train_job .returnn_root ,
409405 log_verbosity = train_job .returnn_config .post_config ["log_verbosity" ],
410406 )
411-
412407 prior_job .add_alias ("extract_nn_prior/" + name )
413408 prior_file = prior_job .out_prior_xml_file
414409 assert prior_file is not None
410+
415411 scorer = rasr .PrecomputedHybridFeatureScorer (
416412 prior_mixtures = acoustic_mixture_path ,
417413 priori_scale = prior ,
418414 prior_file = prior_file ,
419415 )
420- assert acoustic_mixture_path is not None
421416
422417 if use_epoch_for_compile :
423418 tf_graph = self .nn_compile_graph (name , returnn_config , epoch = epoch )
@@ -474,7 +469,6 @@ def nn_recog(
474469 checkpoints = checkpoints ,
475470 train_job = train_job ,
476471 recognition_corpus_key = dev_c ,
477- acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
478472 ** recog_args ,
479473 )
480474
@@ -490,7 +484,6 @@ def nn_recog(
490484 checkpoints = checkpoints ,
491485 train_job = train_job ,
492486 recognition_corpus_key = tst_c ,
493- acoustic_mixture_path = self .train_input_data [train_corpus_key ].acoustic_mixtures ,
494487 ** r_args ,
495488 )
496489
@@ -509,12 +502,8 @@ def nn_compile_graph(
509502 e.g. `def get_network(epoch=...)` in the config
510503 :return: the TF graph
511504 """
512- # TODO remove, temporary hack
513- cfg = returnn_config
514- if "pretrain" in cfg .config .keys ():
515- del cfg .config ["pretrain" ]
516505 graph_compile_job = returnn .CompileTFGraphJob (
517- cfg ,
506+ returnn_config = returnn_config ,
518507 epoch = epoch ,
519508 returnn_root = self .returnn_root ,
520509 returnn_python_exe = self .returnn_python_exe ,
0 commit comments