2323 CombineLmRasrConfig ,
2424)
2525from .util .decode import (
26- DevRecognitionParameters ,
2726 RecognitionParameters ,
2827 SearchJobArgs ,
2928 Lattice2CtmArgs ,
@@ -48,7 +47,7 @@ class HybridDecoder(BaseDecoder):
4847 def __init__ (
4948 self ,
5049 rasr_binary_path : tk .Path ,
51- rasr_arch : str = "linux-x86_64-standard" ,
50+ rasr_arch : " str" = "linux-x86_64-standard" ,
5251 compress : bool = False ,
5352 append : bool = False ,
5453 unbuffered : bool = False ,
@@ -156,9 +155,8 @@ def recognition(
156155 tf_fwd_input_name : str = "tf-fwd-input" ,
157156 ):
158157 """
159- run the recognition , consisting of search, lattice to ctm, and scoring
158+ run the recognitino , consisting of search, lattice to ctm, and scoring
160159
161- :param name: decoding name
162160 :param returnn_config: RETURNN config for recognition
163161 :param checkpoints: epoch to model checkpoint mapping
164162 :param recognition_parameters: keys are the corpus keys so that recog params can be set for specific eval sets.
@@ -223,247 +221,3 @@ def recognition(
223221 scorer_hyp_param_name = scorer_hyp_param_name ,
224222 optimize_pron_lm_scales = optimize_pron_lm_scales ,
225223 )
226-
227-
228- def tune_scales (
229- decoder : HybridDecoder ,
230- name : str ,
231- returnn_config : Union [returnn .ReturnnConfig , tk .Path ],
232- checkpoints : Dict [int , Union [returnn .Checkpoint , tk .Path ]],
233- lm_configs : Dict [str , LmConfig ],
234- prior_paths : Dict [str , PriorPath ],
235- search_job_args : Union [SearchJobArgs , Dict ],
236- lat_2_ctm_args : Union [Lattice2CtmArgs , Dict ],
237- scorer_args : Union [ScliteScorerArgs , Dict ],
238- optimize_parameters : Union [OptimizeJobArgs , Dict ],
239- epochs : Optional [List [int ]] = None ,
240- scorer_hyp_param_name : str = "hyp" ,
241- optimize_pron_lm_scales : bool = False ,
242- forward_output_layer : str = "output" ,
243- tf_fwd_input_name : str = "tf-fwd-input" ,
244- ):
245- """
246- this function tunes the prior scale, TDP scale and silence/non-word exit penalties
247-
248- :return:
249- """
250- recog_params = {
251- "tune1" : [
252- DevRecognitionParameters (
253- am_scales = [1.0 ],
254- lm_scales = [12.0 ],
255- prior_scales = [0.3 , 0.5 , 0.7 ],
256- pronunciation_scales = [1.0 ],
257- tdp_scales = [0.1 , 0.5 , 1.0 ],
258- speech_tdps = [],
259- silence_tdps = [],
260- nonspeech_tdps = [],
261- altas = [12.0 ],
262- ),
263- ],
264- }
265-
266- decoder .recognition (
267- name = name ,
268- returnn_config = returnn_config ,
269- checkpoints = checkpoints ,
270- recognition_parameters = recog_params ,
271- lm_configs = lm_configs ,
272- prior_paths = prior_paths ,
273- search_job_args = search_job_args ,
274- lat_2_ctm_args = lat_2_ctm_args ,
275- scorer_args = scorer_args ,
276- optimize_parameters = optimize_parameters ,
277- epochs = epochs ,
278- scorer_hyp_param_name = scorer_hyp_param_name ,
279- optimize_pron_lm_scales = optimize_pron_lm_scales ,
280- forward_output_layer = forward_output_layer ,
281- tf_fwd_input_name = tf_fwd_input_name ,
282- )
283-
284-
285- def tune_lm_scale (
286- decoder : HybridDecoder ,
287- name : str ,
288- returnn_config : Union [returnn .ReturnnConfig , tk .Path ],
289- checkpoints : Dict [int , Union [returnn .Checkpoint , tk .Path ]],
290- lm_configs : Dict [str , LmConfig ],
291- prior_paths : Dict [str , PriorPath ],
292- search_job_args : Union [SearchJobArgs , Dict ],
293- lat_2_ctm_args : Union [Lattice2CtmArgs , Dict ],
294- scorer_args : Union [ScliteScorerArgs , Dict ],
295- optimize_parameters : Union [OptimizeJobArgs , Dict ],
296- epochs : Optional [List [int ]] = None ,
297- scorer_hyp_param_name : str = "hyp" ,
298- optimize_pron_lm_scales : bool = False ,
299- forward_output_layer : str = "output" ,
300- tf_fwd_input_name : str = "tf-fwd-input" ,
301- ):
302- """
303- tunes the LM scale
304-
305- :return:
306- """
307- recog_params = {
308- "tune2" : [
309- DevRecognitionParameters (
310- am_scales = [1.0 ],
311- lm_scales = [12.0 ],
312- prior_scales = [0.3 , 0.5 , 0.7 ],
313- pronunciation_scales = [1.0 ],
314- tdp_scales = [0.1 , 0.5 , 1.0 ],
315- speech_tdps = [],
316- silence_tdps = [],
317- nonspeech_tdps = [],
318- altas = [0.0 ],
319- ),
320- ],
321- }
322-
323- decoder .recognition (
324- name = name ,
325- returnn_config = returnn_config ,
326- checkpoints = checkpoints ,
327- recognition_parameters = recog_params ,
328- lm_configs = lm_configs ,
329- prior_paths = prior_paths ,
330- search_job_args = search_job_args ,
331- lat_2_ctm_args = lat_2_ctm_args ,
332- scorer_args = scorer_args ,
333- optimize_parameters = optimize_parameters ,
334- epochs = epochs ,
335- scorer_hyp_param_name = scorer_hyp_param_name ,
336- optimize_pron_lm_scales = optimize_pron_lm_scales ,
337- forward_output_layer = forward_output_layer ,
338- tf_fwd_input_name = tf_fwd_input_name ,
339- )
340-
341-
342- def tune_search_space (
343- decoder : HybridDecoder ,
344- name : str ,
345- returnn_config : Union [returnn .ReturnnConfig , tk .Path ],
346- checkpoints : Dict [int , Union [returnn .Checkpoint , tk .Path ]],
347- lm_configs : Dict [str , LmConfig ],
348- prior_paths : Dict [str , PriorPath ],
349- search_job_args : Union [SearchJobArgs , Dict ],
350- lat_2_ctm_args : Union [Lattice2CtmArgs , Dict ],
351- scorer_args : Union [ScliteScorerArgs , Dict ],
352- optimize_parameters : Union [OptimizeJobArgs , Dict ],
353- epochs : Optional [List [int ]] = None ,
354- scorer_hyp_param_name : str = "hyp" ,
355- optimize_pron_lm_scales : bool = False ,
356- forward_output_layer : str = "output" ,
357- tf_fwd_input_name : str = "tf-fwd-input" ,
358- ):
359- """
360- tunes beam search size and altas
361-
362- :return:
363- """
364- recog_params = DevRecognitionParameters ()
365-
366- decoder .recognition ()
367-
368-
369- def tune_beam_pruning_limit (
370- decoder : HybridDecoder ,
371- name : str ,
372- returnn_config : Union [returnn .ReturnnConfig , tk .Path ],
373- checkpoints : Dict [int , Union [returnn .Checkpoint , tk .Path ]],
374- lm_configs : Dict [str , LmConfig ],
375- prior_paths : Dict [str , PriorPath ],
376- search_job_args : Union [SearchJobArgs , Dict ],
377- lat_2_ctm_args : Union [Lattice2CtmArgs , Dict ],
378- scorer_args : Union [ScliteScorerArgs , Dict ],
379- optimize_parameters : Union [OptimizeJobArgs , Dict ],
380- epochs : Optional [List [int ]] = None ,
381- scorer_hyp_param_name : str = "hyp" ,
382- optimize_pron_lm_scales : bool = False ,
383- forward_output_layer : str = "output" ,
384- tf_fwd_input_name : str = "tf-fwd-input" ,
385- ):
386- """
387- tunes the beam pruning limit
388-
389- :return:
390- """
391- recog_params = DevRecognitionParameters ()
392-
393- decoder .recognition ()
394-
395-
396- def tune_decoding (
397- name : str ,
398- * ,
399- rasr_binary_path : tk .Path ,
400- acoustic_model_config : AmRasrConfig ,
401- lexicon_config : LexiconRasrConfig ,
402- returnn_config : Union [returnn .ReturnnConfig , tk .Path ],
403- checkpoints : Dict [int , Union [returnn .Checkpoint , tk .Path ]],
404- lm_configs : Dict [str , LmConfig ],
405- prior_paths : Dict [str , PriorPath ],
406- search_job_args : Union [SearchJobArgs , Dict ],
407- lat_2_ctm_args : Union [Lattice2CtmArgs , Dict ],
408- scorer_args : Union [ScliteScorerArgs , Dict ],
409- optimize_parameters : Union [OptimizeJobArgs , Dict ],
410- rasr_arch : str = "linux-x86_64-standard" ,
411- compress : bool = False ,
412- append : bool = False ,
413- unbuffered : bool = False ,
414- compress_after_run : bool = True ,
415- search_job_class : Type [tk .Job ] = recog .AdvancedTreeSearchJob ,
416- scorer_job_class : Type [tk .Job ] = recog .ScliteJob ,
417- alias_output_prefix : str = "" ,
418- returnn_root : Optional [tk .Path ] = None ,
419- returnn_python_home : Optional [tk .Path ] = None ,
420- returnn_python_exe : Optional [tk .Path ] = None ,
421- blas_lib : Optional [tk .Path ] = None ,
422- search_numpy_blas : bool = True ,
423- required_native_ops : Optional [List [str ]] = None ,
424- extra_configs : Optional [Dict [str , rasr .RasrConfig ]] = None ,
425- crp_name : str = "base" ,
426- epochs : Optional [List [int ]] = None ,
427- scorer_hyp_param_name : str = "hyp" ,
428- optimize_pron_lm_scales : bool = False ,
429- forward_output_layer : str = "output" ,
430- tf_fwd_input_name : str = "tf-fwd-input" ,
431- ):
432- """
433- 1. TDPs, scales: prior, and TDP [beam-pruning = 14.0, altas = 12.0]
434- a. TDP: {0.1, 0.5, 1.0}
435- b. Prior: {0.3, 0.5, 0.7}
436- c. Silence and non-word phon: {0.0, 4.0, 10.0}
437- 2. LM scale optimization
438- a. no altas
439- b. beam-pruning: 14.0, 15.0
440- 3.
441- a. beam-pruning: 14.0, 15.0
442- b. altas: 2.0, 4.0, 6.0, 8.0
443- 4. beam pruning-limit: 15k, 10k, 7.5k, 6k, 5k, 4k
444-
445- :return:
446- """
447- decoder = HybridDecoder (
448- rasr_binary_path = rasr_binary_path ,
449- rasr_arch = rasr_arch ,
450- compress = compress ,
451- append = append ,
452- unbuffered = unbuffered ,
453- compress_after_run = compress_after_run ,
454- search_job_class = search_job_class ,
455- scorer_job_class = scorer_job_class ,
456- alias_output_prefix = alias_output_prefix ,
457- returnn_root = returnn_root ,
458- returnn_python_home = returnn_python_home ,
459- returnn_python_exe = returnn_python_exe ,
460- blas_lib = blas_lib ,
461- search_numpy_blas = search_numpy_blas ,
462- required_native_ops = required_native_ops ,
463- )
464- decoder .init_decoder (
465- acoustic_model_config = acoustic_model_config ,
466- lexicon_config = lexicon_config ,
467- extra_configs = extra_configs ,
468- crp_name = crp_name ,
469- )
0 commit comments