Skip to content

Commit 901bb55

Browse files
committed
update hybrid decoder
1 parent caf244a commit 901bb55

File tree

1 file changed

+2
-248
lines changed

1 file changed

+2
-248
lines changed

common/setups/rasr/hybrid_decoder.py

Lines changed: 2 additions & 248 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
CombineLmRasrConfig,
2424
)
2525
from .util.decode import (
26-
DevRecognitionParameters,
2726
RecognitionParameters,
2827
SearchJobArgs,
2928
Lattice2CtmArgs,
@@ -48,7 +47,7 @@ class HybridDecoder(BaseDecoder):
4847
def __init__(
4948
self,
5049
rasr_binary_path: tk.Path,
51-
rasr_arch: str = "linux-x86_64-standard",
50+
rasr_arch: "str" = "linux-x86_64-standard",
5251
compress: bool = False,
5352
append: bool = False,
5453
unbuffered: bool = False,
@@ -156,9 +155,8 @@ def recognition(
156155
tf_fwd_input_name: str = "tf-fwd-input",
157156
):
158157
"""
159-
run the recognition, consisting of search, lattice to ctm, and scoring
158+
run the recognitino, consisting of search, lattice to ctm, and scoring
160159
161-
:param name: decoding name
162160
:param returnn_config: RETURNN config for recognition
163161
:param checkpoints: epoch to model checkpoint mapping
164162
:param recognition_parameters: keys are the corpus keys so that recog params can be set for specific eval sets.
@@ -223,247 +221,3 @@ def recognition(
223221
scorer_hyp_param_name=scorer_hyp_param_name,
224222
optimize_pron_lm_scales=optimize_pron_lm_scales,
225223
)
226-
227-
228-
def tune_scales(
229-
decoder: HybridDecoder,
230-
name: str,
231-
returnn_config: Union[returnn.ReturnnConfig, tk.Path],
232-
checkpoints: Dict[int, Union[returnn.Checkpoint, tk.Path]],
233-
lm_configs: Dict[str, LmConfig],
234-
prior_paths: Dict[str, PriorPath],
235-
search_job_args: Union[SearchJobArgs, Dict],
236-
lat_2_ctm_args: Union[Lattice2CtmArgs, Dict],
237-
scorer_args: Union[ScliteScorerArgs, Dict],
238-
optimize_parameters: Union[OptimizeJobArgs, Dict],
239-
epochs: Optional[List[int]] = None,
240-
scorer_hyp_param_name: str = "hyp",
241-
optimize_pron_lm_scales: bool = False,
242-
forward_output_layer: str = "output",
243-
tf_fwd_input_name: str = "tf-fwd-input",
244-
):
245-
"""
246-
this function tunes the prior scale, TDP scale and silence/non-word exit penalties
247-
248-
:return:
249-
"""
250-
recog_params = {
251-
"tune1": [
252-
DevRecognitionParameters(
253-
am_scales=[1.0],
254-
lm_scales=[12.0],
255-
prior_scales=[0.3, 0.5, 0.7],
256-
pronunciation_scales=[1.0],
257-
tdp_scales=[0.1, 0.5, 1.0],
258-
speech_tdps=[],
259-
silence_tdps=[],
260-
nonspeech_tdps=[],
261-
altas=[12.0],
262-
),
263-
],
264-
}
265-
266-
decoder.recognition(
267-
name=name,
268-
returnn_config=returnn_config,
269-
checkpoints=checkpoints,
270-
recognition_parameters=recog_params,
271-
lm_configs=lm_configs,
272-
prior_paths=prior_paths,
273-
search_job_args=search_job_args,
274-
lat_2_ctm_args=lat_2_ctm_args,
275-
scorer_args=scorer_args,
276-
optimize_parameters=optimize_parameters,
277-
epochs=epochs,
278-
scorer_hyp_param_name=scorer_hyp_param_name,
279-
optimize_pron_lm_scales=optimize_pron_lm_scales,
280-
forward_output_layer=forward_output_layer,
281-
tf_fwd_input_name=tf_fwd_input_name,
282-
)
283-
284-
285-
def tune_lm_scale(
286-
decoder: HybridDecoder,
287-
name: str,
288-
returnn_config: Union[returnn.ReturnnConfig, tk.Path],
289-
checkpoints: Dict[int, Union[returnn.Checkpoint, tk.Path]],
290-
lm_configs: Dict[str, LmConfig],
291-
prior_paths: Dict[str, PriorPath],
292-
search_job_args: Union[SearchJobArgs, Dict],
293-
lat_2_ctm_args: Union[Lattice2CtmArgs, Dict],
294-
scorer_args: Union[ScliteScorerArgs, Dict],
295-
optimize_parameters: Union[OptimizeJobArgs, Dict],
296-
epochs: Optional[List[int]] = None,
297-
scorer_hyp_param_name: str = "hyp",
298-
optimize_pron_lm_scales: bool = False,
299-
forward_output_layer: str = "output",
300-
tf_fwd_input_name: str = "tf-fwd-input",
301-
):
302-
"""
303-
tunes the LM scale
304-
305-
:return:
306-
"""
307-
recog_params = {
308-
"tune2": [
309-
DevRecognitionParameters(
310-
am_scales=[1.0],
311-
lm_scales=[12.0],
312-
prior_scales=[0.3, 0.5, 0.7],
313-
pronunciation_scales=[1.0],
314-
tdp_scales=[0.1, 0.5, 1.0],
315-
speech_tdps=[],
316-
silence_tdps=[],
317-
nonspeech_tdps=[],
318-
altas=[0.0],
319-
),
320-
],
321-
}
322-
323-
decoder.recognition(
324-
name=name,
325-
returnn_config=returnn_config,
326-
checkpoints=checkpoints,
327-
recognition_parameters=recog_params,
328-
lm_configs=lm_configs,
329-
prior_paths=prior_paths,
330-
search_job_args=search_job_args,
331-
lat_2_ctm_args=lat_2_ctm_args,
332-
scorer_args=scorer_args,
333-
optimize_parameters=optimize_parameters,
334-
epochs=epochs,
335-
scorer_hyp_param_name=scorer_hyp_param_name,
336-
optimize_pron_lm_scales=optimize_pron_lm_scales,
337-
forward_output_layer=forward_output_layer,
338-
tf_fwd_input_name=tf_fwd_input_name,
339-
)
340-
341-
342-
def tune_search_space(
343-
decoder: HybridDecoder,
344-
name: str,
345-
returnn_config: Union[returnn.ReturnnConfig, tk.Path],
346-
checkpoints: Dict[int, Union[returnn.Checkpoint, tk.Path]],
347-
lm_configs: Dict[str, LmConfig],
348-
prior_paths: Dict[str, PriorPath],
349-
search_job_args: Union[SearchJobArgs, Dict],
350-
lat_2_ctm_args: Union[Lattice2CtmArgs, Dict],
351-
scorer_args: Union[ScliteScorerArgs, Dict],
352-
optimize_parameters: Union[OptimizeJobArgs, Dict],
353-
epochs: Optional[List[int]] = None,
354-
scorer_hyp_param_name: str = "hyp",
355-
optimize_pron_lm_scales: bool = False,
356-
forward_output_layer: str = "output",
357-
tf_fwd_input_name: str = "tf-fwd-input",
358-
):
359-
"""
360-
tunes beam search size and altas
361-
362-
:return:
363-
"""
364-
recog_params = DevRecognitionParameters()
365-
366-
decoder.recognition()
367-
368-
369-
def tune_beam_pruning_limit(
370-
decoder: HybridDecoder,
371-
name: str,
372-
returnn_config: Union[returnn.ReturnnConfig, tk.Path],
373-
checkpoints: Dict[int, Union[returnn.Checkpoint, tk.Path]],
374-
lm_configs: Dict[str, LmConfig],
375-
prior_paths: Dict[str, PriorPath],
376-
search_job_args: Union[SearchJobArgs, Dict],
377-
lat_2_ctm_args: Union[Lattice2CtmArgs, Dict],
378-
scorer_args: Union[ScliteScorerArgs, Dict],
379-
optimize_parameters: Union[OptimizeJobArgs, Dict],
380-
epochs: Optional[List[int]] = None,
381-
scorer_hyp_param_name: str = "hyp",
382-
optimize_pron_lm_scales: bool = False,
383-
forward_output_layer: str = "output",
384-
tf_fwd_input_name: str = "tf-fwd-input",
385-
):
386-
"""
387-
tunes the beam pruning limit
388-
389-
:return:
390-
"""
391-
recog_params = DevRecognitionParameters()
392-
393-
decoder.recognition()
394-
395-
396-
def tune_decoding(
397-
name: str,
398-
*,
399-
rasr_binary_path: tk.Path,
400-
acoustic_model_config: AmRasrConfig,
401-
lexicon_config: LexiconRasrConfig,
402-
returnn_config: Union[returnn.ReturnnConfig, tk.Path],
403-
checkpoints: Dict[int, Union[returnn.Checkpoint, tk.Path]],
404-
lm_configs: Dict[str, LmConfig],
405-
prior_paths: Dict[str, PriorPath],
406-
search_job_args: Union[SearchJobArgs, Dict],
407-
lat_2_ctm_args: Union[Lattice2CtmArgs, Dict],
408-
scorer_args: Union[ScliteScorerArgs, Dict],
409-
optimize_parameters: Union[OptimizeJobArgs, Dict],
410-
rasr_arch: str = "linux-x86_64-standard",
411-
compress: bool = False,
412-
append: bool = False,
413-
unbuffered: bool = False,
414-
compress_after_run: bool = True,
415-
search_job_class: Type[tk.Job] = recog.AdvancedTreeSearchJob,
416-
scorer_job_class: Type[tk.Job] = recog.ScliteJob,
417-
alias_output_prefix: str = "",
418-
returnn_root: Optional[tk.Path] = None,
419-
returnn_python_home: Optional[tk.Path] = None,
420-
returnn_python_exe: Optional[tk.Path] = None,
421-
blas_lib: Optional[tk.Path] = None,
422-
search_numpy_blas: bool = True,
423-
required_native_ops: Optional[List[str]] = None,
424-
extra_configs: Optional[Dict[str, rasr.RasrConfig]] = None,
425-
crp_name: str = "base",
426-
epochs: Optional[List[int]] = None,
427-
scorer_hyp_param_name: str = "hyp",
428-
optimize_pron_lm_scales: bool = False,
429-
forward_output_layer: str = "output",
430-
tf_fwd_input_name: str = "tf-fwd-input",
431-
):
432-
"""
433-
1. TDPs, scales: prior, and TDP [beam-pruning = 14.0, altas = 12.0]
434-
a. TDP: {0.1, 0.5, 1.0}
435-
b. Prior: {0.3, 0.5, 0.7}
436-
c. Silence and non-word phon: {0.0, 4.0, 10.0}
437-
2. LM scale optimization
438-
a. no altas
439-
b. beam-pruning: 14.0, 15.0
440-
3.
441-
a. beam-pruning: 14.0, 15.0
442-
b. altas: 2.0, 4.0, 6.0, 8.0
443-
4. beam pruning-limit: 15k, 10k, 7.5k, 6k, 5k, 4k
444-
445-
:return:
446-
"""
447-
decoder = HybridDecoder(
448-
rasr_binary_path=rasr_binary_path,
449-
rasr_arch=rasr_arch,
450-
compress=compress,
451-
append=append,
452-
unbuffered=unbuffered,
453-
compress_after_run=compress_after_run,
454-
search_job_class=search_job_class,
455-
scorer_job_class=scorer_job_class,
456-
alias_output_prefix=alias_output_prefix,
457-
returnn_root=returnn_root,
458-
returnn_python_home=returnn_python_home,
459-
returnn_python_exe=returnn_python_exe,
460-
blas_lib=blas_lib,
461-
search_numpy_blas=search_numpy_blas,
462-
required_native_ops=required_native_ops,
463-
)
464-
decoder.init_decoder(
465-
acoustic_model_config=acoustic_model_config,
466-
lexicon_config=lexicon_config,
467-
extra_configs=extra_configs,
468-
crp_name=crp_name,
469-
)

0 commit comments

Comments
 (0)