Skip to content

Commit 4b5e42e

Browse files
committed
make use of nn prior optional
1 parent 6be293d commit 4b5e42e

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

common/baselines/tedlium2/hybrid/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def get_corpus_data_inputs(
204204
allophone_labeling=allophone_labeling,
205205
alias_prefix=alias_prefix + "/nn_train_data",
206206
partition_epoch=5,
207-
acoustic_mixtures=gmm_system.outputs["train"]["final"].acoustic_mixtures,
207+
acoustic_mixtures=None,
208208
seq_ordering="laplace:.1000",
209209
)
210210
tk.register_output(f"{alias_prefix}/nn_train_data/features", nn_train_data.features)

common/setups/rasr/hybrid_system.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def nn_recognition(
354354
name: str,
355355
returnn_config: returnn.ReturnnConfig,
356356
checkpoints: Dict[int, returnn.Checkpoint],
357-
train_job: Union[returnn.ReturnnTrainingJob, returnn.ReturnnRasrTrainingJob],
357+
acoustic_mixture_path: Optional[tk.Path], # if None, dummy mixtures are created and the NN prior is computed from train_job; TODO: maybe also accept a prior file directly and construct the dummy mixture file automatically
358358
prior_scales: List[float],
359359
pronunciation_scales: List[float],
360360
lm_scales: List[float],
@@ -370,6 +370,7 @@ def nn_recognition(
370370
use_epoch_for_compile=False,
371371
forward_output_layer="output",
372372
native_ops: Optional[List[str]] = None,
373+
train_job: Optional[Union[returnn.ReturnnTrainingJob, returnn.ReturnnRasrTrainingJob]] = None,
373374
**kwargs,
374375
):
375376
with tk.block(f"{name}_recognition"):
@@ -393,24 +394,28 @@ def nn_recognition(
393394
for pron, lm, prior, epoch in itertools.product(pronunciation_scales, lm_scales, prior_scales, epochs):
394395

395396
assert epoch in checkpoints.keys()
396-
acoustic_mixture_path = CreateDummyMixturesJob(
397-
num_mixtures=returnn_config.config["extern_data"]["classes"]["dim"],
398-
num_features=returnn_config.config["extern_data"]["data"]["dim"],
399-
).out_mixtures
400-
lmgc_scorer = rasr.GMMFeatureScorer(acoustic_mixture_path)
401-
prior_job = ReturnnComputePriorJobV2(
402-
model_checkpoint=checkpoints[epoch],
403-
returnn_config=train_job.returnn_config,
404-
returnn_python_exe=train_job.returnn_python_exe,
405-
returnn_root=train_job.returnn_root,
406-
log_verbosity=train_job.returnn_config.post_config["log_verbosity"],
407-
)
408-
prior_job.add_alias("extract_nn_prior/" + name)
409-
prior_file = prior_job.out_prior_xml_file
410-
assert prior_file is not None
411-
397+
prior_file = None
398+
lmgc_scorer = None
399+
if acoustic_mixture_path is None:
400+
assert train_job is not None, "Need ReturnnTrainingJob for computation of priors"
401+
tmp_acoustic_mixture_path = CreateDummyMixturesJob(
402+
num_mixtures=returnn_config.config["extern_data"]["classes"]["dim"],
403+
num_features=returnn_config.config["extern_data"]["data"]["dim"],
404+
).out_mixtures
405+
lmgc_scorer = rasr.GMMFeatureScorer(tmp_acoustic_mixture_path)
406+
prior_job = ReturnnComputePriorJobV2(
407+
model_checkpoint=checkpoints[epoch],
408+
returnn_config=train_job.returnn_config,
409+
returnn_python_exe=train_job.returnn_python_exe,
410+
returnn_root=train_job.returnn_root,
411+
log_verbosity=train_job.returnn_config.post_config["log_verbosity"],
412+
)
413+
prior_job.add_alias("extract_nn_prior/" + name)
414+
prior_file = prior_job.out_prior_xml_file
415+
else:
416+
tmp_acoustic_mixture_path = acoustic_mixture_path
412417
scorer = rasr.PrecomputedHybridFeatureScorer(
413-
prior_mixtures=acoustic_mixture_path,
418+
prior_mixtures=tmp_acoustic_mixture_path, # Must be a separate variable: reassigning acoustic_mixture_path inside the loop would make the "is None" check fail on later iterations, so no prior would be computed for them
414419
priori_scale=prior,
415420
prior_file=prior_file,
416421
)
@@ -468,6 +473,7 @@ def nn_recog(
468473
name=f"{train_corpus_key}-{train_name}-{recog_name}",
469474
returnn_config=returnn_config,
470475
checkpoints=checkpoints,
476+
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
471477
train_job=train_job,
472478
recognition_corpus_key=dev_c,
473479
**recog_args,
@@ -483,6 +489,7 @@ def nn_recog(
483489
name=f"{train_name}-{recog_name}",
484490
returnn_config=returnn_config,
485491
checkpoints=checkpoints,
492+
acoustic_mixture_path=self.train_input_data[train_corpus_key].acoustic_mixtures,
486493
train_job=train_job,
487494
recognition_corpus_key=tst_c,
488495
**r_args,

0 commit comments

Comments
 (0)