@@ -658,9 +658,8 @@ def recognize(
658658 name += f"-prJ-C{ search_parameters .prior_info .diphone_prior .scale } "
659659 if search_parameters .we_pruning > 0.5 :
660660 name += f"-wep{ search_parameters .we_pruning } "
661- if search_parameters .we_pruning_limit < 5000 :
661+ if search_parameters .we_pruning_limit < 5000 or search_parameters . we_pruning_limit > 10000 :
662662 # condition for rtf
663- name += f"-wep{ search_parameters .we_pruning } "
664663 name += f"-wepLim{ search_parameters .we_pruning_limit } "
665664 if search_parameters .altas is not None :
666665 name += f"-ALTAS{ search_parameters .altas } "
@@ -705,10 +704,8 @@ def recognize(
705704 tdp_silence = (
706705 search_parameters .tdp_silence if search_parameters .tdp_scale is not None else (0.0 , 0.0 , "infinity" , 0.0 )
707706 )
708- tdp_non_word = (
709- search_parameters .tdp_non_word
710- if search_parameters .tdp_non_word is not None
711- else (0.0 , 0.0 , "infinity" , 0.0 )
707+ tdp_nonword = (
708+ search_parameters .tdp_nonword if search_parameters .tdp_nonword is not None else (0.0 , 0.0 , "infinity" , 0.0 )
712709 )
713710
714711 search_crp .acoustic_model_config = am .acoustic_model_config (
@@ -720,7 +717,7 @@ def recognize(
720717 tdp_scale = search_parameters .tdp_scale ,
721718 tdp_transition = tdp_transition ,
722719 tdp_silence = tdp_silence ,
723- tdp_nonword = tdp_non_word ,
720+ tdp_nonword = tdp_nonword ,
724721 nonword_phones = search_parameters .non_word_phonemes ,
725722 tying_type = "global-and-nonword" ,
726723 )
@@ -1152,6 +1149,205 @@ def push_delayed_tuple(
11521149 left = best_left_prior ,
11531150 right = best_right_prior ,
11541151 )
1152+
def recognize_optimize_scales_v2(
    self,
    *,
    label_info: LabelInfo,
    num_encoder_output: int,
    search_parameters: SearchParameters,
    prior_scales: Union[
        List[Tuple[float]],  # center
        List[Tuple[float, float]],  # center, left
        List[Tuple[float, float, float]],  # center, left, right
        np.ndarray,
    ],
    tdp_scales: Union[List[float], np.ndarray],
    tdp_sil: Optional[List[Tuple[TDP, TDP, TDP, TDP]]] = None,
    tdp_nonword: Optional[List[Tuple[TDP, TDP, TDP, TDP]]] = None,
    tdp_speech: Optional[List[Tuple[TDP, TDP, TDP, TDP]]] = None,
    pron_scales: Optional[Union[List[float], np.ndarray]] = None,
    altas_value=14.0,
    altas_beam=14.0,
    keep_value=10,
    gpu: Optional[bool] = None,
    cpu_rqmt: Optional[int] = None,
    mem_rqmt: Optional[int] = None,
    crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
    pre_path: str = "scales",
    cpu_slow: bool = True,
) -> SearchParameters:
    """
    Run one recognition per combination of prior scales, TDP scale, silence /
    non-word / speech TDPs (and optionally pronunciation scale) and return the
    search parameters updated with the combination that has the fewest errors.

    :param label_info: forwarded unchanged to `recognize_count_lm`.
    :param num_encoder_output: forwarded unchanged to `recognize_count_lm`.
    :param search_parameters: base parameters. The sweep itself runs on a copy with
        `altas`/`beam` overridden; the returned parameters are derived from this
        original object, so they keep its own `altas`/`beam`.
    :param prior_scales: candidates given as (center,), (center, left) or
        (center, left, right). Missing entries are padded with 0.0, values are
        rounded to two decimals. A 1-D array (or bare scalars) is read as center only.
    :param tdp_scales: candidate TDP scales, rounded to two decimals.
    :param tdp_sil: candidate silence TDPs, defaults to the one in `search_parameters`.
    :param tdp_nonword: candidate non-word TDPs, defaults to the one in `search_parameters`.
    :param tdp_speech: candidate speech TDPs, defaults to the one in `search_parameters`.
    :param pron_scales: candidate pronunciation scales. Only swept if the lexicon
        config has `normalize_pronunciation` set.
    :param altas_value: value written to `altas` for the sweep recognitions.
    :param altas_beam: value written to `beam` for the sweep recognitions.
    :param keep_value: keep value set on the search and lat2ctm jobs.
    :param gpu: forwarded unchanged to `recognize_count_lm`.
    :param cpu_rqmt: forwarded unchanged to `recognize_count_lm`.
    :param mem_rqmt: forwarded unchanged to `recognize_count_lm`.
    :param crp_update: forwarded unchanged to `recognize_count_lm`.
    :param pre_path: prefix of the alias/output path of every sweep job.
    :param cpu_slow: whether to add `cpu_slow` to the run requirements of the search jobs.
    :return: `search_parameters` with the best (delayed) TDP scale, TDPs, optionally
        pronunciation scale and the prior scales relevant for `self.context_type`.
    """
    assert len(prior_scales) > 0
    assert len(tdp_scales) > 0

    recog_args = dataclasses.replace(search_parameters, altas=altas_value, beam=altas_beam)

    if isinstance(prior_scales, np.ndarray):
        prior_scales = [(s,) for s in prior_scales] if prior_scales.ndim == 1 else [tuple(s) for s in prior_scales]

    def normalize_priors(priors):
        """Round to two decimals and pad to a (center, left, right) triple."""
        if isinstance(priors, (int, float, np.number)):
            # a bare scalar is a center-only prior scale
            priors = (priors,)
        rounded = tuple(round(p, 2) for p in priors)
        return rounded + (0.0,) * (3 - len(rounded))

    prior_scales = [normalize_priors(priors) for priors in prior_scales]
    tdp_scales = [round(s, 2) for s in tdp_scales]
    tdp_sil = tdp_sil if tdp_sil is not None else [recog_args.tdp_silence]
    tdp_nonword = tdp_nonword if tdp_nonword is not None else [recog_args.tdp_nonword]
    tdp_speech = tdp_speech if tdp_speech is not None else [recog_args.tdp_speech]

    use_pron = self.crp.lexicon_config.normalize_pronunciation and pron_scales is not None

    # Key layout of every job (and therefore of the argmin result below):
    #   0: (center, left, right), 1: tdp scale, 2: tdp silence,
    #   3: tdp non-word, 4: tdp speech, 5: pronunciation scale (only if use_pron)
    axes = [prior_scales, tdp_scales, tdp_sil, tdp_nonword, tdp_speech]
    if use_pron:
        axes.append(pron_scales)

    jobs = {}
    for key in itertools.product(*axes):
        (center, left, right), tdp, tdp_sl, tdp_nw, tdp_sp = key[:5]

        job_name = (
            f"{self.name}-pC{center}-pL{left}-pR{right}-tdp{tdp}-tdpSil{tdp_sl}-tdpNnw{tdp_nw}-tdpSp{tdp_sp}-"
        )
        overrides = {
            "tdp_scale": tdp,
            "tdp_silence": tdp_sl,
            "tdp_nonword": tdp_nw,
            "tdp_speech": tdp_sp,
        }
        if use_pron:
            job_name += f"pron{key[5]}"
            overrides["pron_scale"] = key[5]

        jobs[key] = self.recognize_count_lm(
            add_sis_alias_and_output=False,
            calculate_stats=False,
            cpu_rqmt=cpu_rqmt,
            crp_update=crp_update,
            gpu=gpu,
            is_min_duration=False,
            keep_value=keep_value,
            label_info=label_info,
            mem_rqmt=mem_rqmt,
            name_override=job_name,
            num_encoder_output=num_encoder_output,
            opt_lm_am=False,
            rerun_after_opt_lm=False,
            search_parameters=dataclasses.replace(recog_args, **overrides).with_prior_scale(
                left=left, center=center, right=right, diphone=center
            ),
            remove_or_set_concurrency=False,
        )
    jobs_num_e = {k: v.scorer.out_num_errors for k, v in jobs.items()}

    for key, recog_jobs in jobs.items():
        (center, left, right), tdp, tdp_sl, tdp_nw, tdp_sp = key[:5]
        # without a pron sweep the alias shows the fixed pronunciation scale
        pron = key[5] if use_pron else recog_args.pron_scale

        if cpu_slow:
            recog_jobs.search.update_rqmt("run", {"cpu_slow": True})

        pre_name = (
            f"{pre_path}/{self.name}/Lm{recog_args.lm_scale}-Pron{pron}"
            f"-pC{center}-pL{left}-pR{right}-tdp{tdp}"
            f"-tdpSil{format_tdp(tdp_sl)}-tdpNw{format_tdp(tdp_nw)}-tdpSp{format_tdp(tdp_sp)}"
        )

        recog_jobs.lat2ctm.set_keep_value(keep_value)
        recog_jobs.search.set_keep_value(keep_value)

        recog_jobs.search.add_alias(pre_name)
        tk.register_output(f"{pre_name}.wer", recog_jobs.scorer.out_report_dir)

    best_overall_wer = ComputeArgminJob({k: v.scorer.out_wer for k, v in jobs.items()})
    best_overall_n = ComputeArgminJob(jobs_num_e)
    tk.register_output(
        f"decoding/scales-best/{self.name}/args",
        best_overall_n.out_argmin,
    )
    tk.register_output(
        f"decoding/scales-best/{self.name}/wer",
        best_overall_wer.out_min,
    )

    def push_delayed_tuple(
        argmin: DelayedBase,
    ) -> Tuple[DelayedBase, DelayedBase, DelayedBase, DelayedBase]:
        return tuple(argmin[i] for i in range(4))

    # cannot destructure, need to use indices (same layout as the job keys above)
    best = best_overall_n.out_argmin
    best_priors = best[0]
    best_values = {
        "tdp_scale": best[1],
        "tdp_silence": push_delayed_tuple(best[2]),
        "tdp_nonword": push_delayed_tuple(best[3]),
        "tdp_speech": push_delayed_tuple(best[4]),
    }
    if use_pron:
        best_values["pron_scale"] = best[5]

    base_cfg = dataclasses.replace(search_parameters, **best_values)

    best_center_prior = best_priors[0]
    if self.context_type.is_monophone():
        return base_cfg.with_prior_scale(center=best_center_prior)
    if self.context_type.is_joint_diphone():
        return base_cfg.with_prior_scale(diphone=best_center_prior)

    best_left_prior = best_priors[1]
    if self.context_type.is_diphone():
        return base_cfg.with_prior_scale(center=best_center_prior, left=best_left_prior)

    best_right_prior = best_priors[2]
    return base_cfg.with_prior_scale(
        center=best_center_prior,
        left=best_left_prior,
        right=best_right_prior,
    )
1349+
1350+
11551351
11561352
11571353class BASEFactoredHybridAligner (BASEFactoredHybridDecoder ):
@@ -1267,9 +1463,9 @@ def get_alignment_job(
12671463 if alignment_parameters .tdp_scale is not None
12681464 else (0.0 , 0.0 , "infinity" , 0.0 )
12691465 )
1270- tdp_non_word = (
1271- alignment_parameters .tdp_non_word
1272- if alignment_parameters .tdp_non_word is not None
1466+ tdp_nonword = (
1467+ alignment_parameters .tdp_nonword
1468+ if alignment_parameters .tdp_nonword is not None
12731469 else (0.0 , 0.0 , "infinity" , 0.0 )
12741470 )
12751471
@@ -1282,7 +1478,7 @@ def get_alignment_job(
12821478 tdp_scale = alignment_parameters .tdp_scale ,
12831479 tdp_transition = tdp_transition ,
12841480 tdp_silence = tdp_silence ,
1285- tdp_nonword = tdp_non_word ,
1481+ tdp_nonword = tdp_nonword ,
12861482 nonword_phones = alignment_parameters .non_word_phonemes ,
12871483 tying_type = "global-and-nonword" ,
12881484 )
@@ -1345,7 +1541,7 @@ def get_alignment_job(
13451541 if (
13461542 alignment_parameters .tdp_speech [- 1 ]
13471543 + alignment_parameters .tdp_silence [- 1 ]
1348- + alignment_parameters .tdp_non_word [- 1 ]
1544+ + alignment_parameters .tdp_nonword [- 1 ]
13491545 > 0.0
13501546 ):
13511547 import warnings
0 commit comments