Skip to content

Commit e61f694

Browse files
committed
update hybrid system
1 parent 901bb55 commit e61f694

File tree

1 file changed

+11
-22
lines changed

1 file changed

+11
-22
lines changed

common/setups/rasr/hybrid_system.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["HybridSystem"]
1+
__all__ = ["HybridArgs", "HybridSystem"]
22

33
import copy
44
import itertools
@@ -94,17 +94,13 @@ def __init__(
9494
self.cv_corpora = []
9595
self.devtrain_corpora = []
9696

97-
self.train_input_data = (
98-
None
99-
) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
100-
self.cv_input_data = (
101-
None
102-
) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
103-
self.devtrain_input_data = (
104-
None
105-
) # type:Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]]
106-
self.dev_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
107-
self.test_input_data = None # type:Optional[Dict[str, ReturnnRasrDataInput]]
97+
self.train_input_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None
98+
self.cv_input_data: Optional[Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]] = None
99+
self.devtrain_input_data: Optional[
100+
Dict[str, Union[ReturnnRasrDataInput, AllowedReturnnTrainingDataInput]]
101+
] = None
102+
self.dev_input_data: Optional[Dict[str, ReturnnRasrDataInput]] = None
103+
self.test_input_data: Optional[Dict[str, ReturnnRasrDataInput]] = None
108104

109105
self.train_cv_pairing = None
110106

@@ -373,7 +369,6 @@ def nn_recognition(
373369
use_epoch_for_compile=False,
374370
forward_output_layer="output",
375371
native_ops: Optional[List[str]] = None,
376-
acoustic_mixture_path: Optional[tk.Path] = None,
377372
**kwargs,
378373
):
379374
with tk.block(f"{name}_recognition"):
@@ -395,6 +390,7 @@ def nn_recognition(
395390
epochs = epochs if epochs is not None else list(checkpoints.keys())
396391

397392
for pron, lm, prior, epoch in itertools.product(pronunciation_scales, lm_scales, prior_scales, epochs):
393+
398394
assert epoch in checkpoints.keys()
399395
acoustic_mixture_path = CreateDummyMixturesJob(
400396
num_mixtures=returnn_config.config["extern_data"]["classes"]["dim"],
@@ -408,16 +404,15 @@ def nn_recognition(
408404
returnn_root=train_job.returnn_root,
409405
log_verbosity=train_job.returnn_config.post_config["log_verbosity"],
410406
)
411-
412407
prior_job.add_alias("extract_nn_prior/" + name)
413408
prior_file = prior_job.out_prior_xml_file
414409
assert prior_file is not None
410+
415411
scorer = rasr.PrecomputedHybridFeatureScorer(
416412
prior_mixtures=acoustic_mixture_path,
417413
priori_scale=prior,
418414
prior_file=prior_file,
419415
)
420-
assert acoustic_mixture_path is not None
421416

422417
if use_epoch_for_compile:
423418
tf_graph = self.nn_compile_graph(name, returnn_config, epoch=epoch)
@@ -474,7 +469,6 @@ def nn_recog(
474469
checkpoints=checkpoints,
475470
train_job=train_job,
476471
recognition_corpus_key=dev_c,
477-
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
478472
**recog_args,
479473
)
480474

@@ -490,7 +484,6 @@ def nn_recog(
490484
checkpoints=checkpoints,
491485
train_job=train_job,
492486
recognition_corpus_key=tst_c,
493-
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
494487
**r_args,
495488
)
496489

@@ -509,12 +502,8 @@ def nn_compile_graph(
509502
e.g. `def get_network(epoch=...)` in the config
510503
:return: the TF graph
511504
"""
512-
# TODO remove, temporary hack
513-
cfg = returnn_config
514-
if "pretrain" in cfg.config.keys():
515-
del cfg.config["pretrain"]
516505
graph_compile_job = returnn.CompileTFGraphJob(
517-
cfg,
506+
returnn_config=returnn_config,
518507
epoch=epoch,
519508
returnn_root=self.returnn_root,
520509
returnn_python_exe=self.returnn_python_exe,

0 commit comments

Comments
 (0)