
Commit fe597c8

wip torch and factored priors
1 parent: db721e4

27 files changed: +922, -155 lines


users/raissi/setups/common/BASE_factored_hybrid_system.py

Lines changed: 1 addition & 0 deletions
@@ -1015,6 +1015,7 @@ def create_hdf(self):
         self.hdfs[self.train_key] = hdf_job.out_hdf_files

         hdf_job.add_alias(f"hdf/{self.train_key}")
+        tk.register_output("hdf/hdf.train.1", hdf_job.out_hdf_files[0])

         return hdf_job
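
The added line publishes the first training HDF dump as a named Sisyphus output next to the existing job alias. A minimal sketch of the same pattern in isolation (hdf_job stands for any Sisyphus job exposing a list of HDF paths as out_hdf_files; only tk.register_output and add_alias are real Sisyphus calls):

    from sisyphus import tk

    # assumed: hdf_job is a Sisyphus job whose out_hdf_files is a list of tk.Path objects
    hdf_job.add_alias("hdf/train")  # readable alias for the job directory
    tk.register_output("hdf/hdf.train.1", hdf_job.out_hdf_files[0])  # symlink under output/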

users/raissi/setups/common/TF_factored_hybrid_system.py

Lines changed: 8 additions & 0 deletions
@@ -278,6 +278,9 @@ def get_conformer_network_zhou_variant(
             }
             network["classes_"]["from"] = "slice_classes"

+        else:
+            network = encoder_net
+
         return network

     # -------------------------------------------- Training --------------------------------------------------------
@@ -733,6 +736,10 @@ def set_diphone_priors_returnn_rasr(

         self.experiments[key]["priors"] = p_info

+
+    def set_triphone_priors_factored(self):
+        self.create_hdf()
+
     def set_triphone_priors_returnn_rasr(
         self,
         key: str,
@@ -830,6 +837,7 @@ def set_triphone_priors_returnn_rasr(
     def set_graph_for_experiment(
         self, key, override_cfg: Optional[returnn.ReturnnConfig] = None, graph_type_name: Optional[str] = None
     ):
+
         config = copy.deepcopy(override_cfg if override_cfg is not None else self.experiments[key]["returnn_config"])

         name = self.experiments[key]["name"]
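
The new else branch makes the network builder fall through with the encoder network untouched when the sliced-target handling above it does not apply, and set_triphone_priors_factored is for now only a stub that triggers the HDF dump via create_hdf(). A small sketch of the control flow added to the builder (build_network and use_sliced_targets are illustrative names, not the real signature of get_conformer_network_zhou_variant):

    # illustrative only: mirrors the if/else structure of the changed builder
    def build_network(encoder_net: dict, use_sliced_targets: bool) -> dict:
        if use_sliced_targets:
            network = dict(encoder_net)
            # stand-in for the "classes_" / "slice_classes" wiring kept as context above
            network["classes_"] = {"class": "copy", "from": "slice_classes"}
        else:
            # new in this commit: return the encoder network unchanged
            network = encoder_net
        return network

    assert build_network({"encoder": {}}, use_sliced_targets=False) == {"encoder": {}}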

users/raissi/setups/common/data/pipeline_helpers.py

Lines changed: 18 additions & 3 deletions
@@ -1,12 +1,15 @@
 __all__ = [
-    "TrainingCriterion",
-    "SingleSoftmaxType",
-    "Experiment"
+    "Experiment",
+    "PriorType",
+    "SingleSoftmaxType",
+    "TrainingCriterion",
 ]


 from dataclasses import dataclass
-from enum import Enum
+from enum import Enum, auto
 from typing import Optional, TypedDict

 import i6_core.mm as mm
@@ -38,6 +41,8 @@ class SingleSoftmaxType(Enum):
     def __str__(self):
         return self.value

+
+
 class Experiment(TypedDict):
     """
     The class is used in the config files as a single experiment
@@ -56,3 +61,13 @@ class InputKey(Enum):
     BASE = "standard-system-input"
     HDF = "hdf-input"

+
+
+class PriorType(Enum):
+    """The way the priors are obtained."""
+
+    TRANSCRIPT = auto()
+    AVERAGE = auto()
+    ONTHEFLY = auto()
+
+    def __str__(self):
+        return self.name
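
Because PriorType uses auto(), its values are plain integers; a small self-contained check of the intended behaviour, written against the corrected __str__ above and independent of the rest of the setup:

    from enum import Enum, auto

    class PriorType(Enum):
        """The way the priors are obtained (mirrors the class added above)."""
        TRANSCRIPT = auto()
        AVERAGE = auto()
        ONTHEFLY = auto()

        def __str__(self):
            # auto() assigns integers starting at 1, so expose the member name as the string form
            return self.name

    assert PriorType.TRANSCRIPT.value == 1
    assert str(PriorType.ONTHEFLY) == "ONTHEFLY"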

users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py

Lines changed: 208 additions & 12 deletions
@@ -658,9 +658,8 @@ def recognize(
             name += f"-prJ-C{search_parameters.prior_info.diphone_prior.scale}"
         if search_parameters.we_pruning > 0.5:
             name += f"-wep{search_parameters.we_pruning}"
-        if search_parameters.we_pruning_limit < 5000:
+        if search_parameters.we_pruning_limit < 5000 or search_parameters.we_pruning_limit > 10000:
             # condition for rtf
-            name += f"-wep{search_parameters.we_pruning}"
             name += f"-wepLim{search_parameters.we_pruning_limit}"
         if search_parameters.altas is not None:
             name += f"-ALTAS{search_parameters.altas}"
@@ -705,10 +704,8 @@ def recognize(
         tdp_silence = (
             search_parameters.tdp_silence if search_parameters.tdp_scale is not None else (0.0, 0.0, "infinity", 0.0)
         )
-        tdp_non_word = (
-            search_parameters.tdp_non_word
-            if search_parameters.tdp_non_word is not None
-            else (0.0, 0.0, "infinity", 0.0)
+        tdp_nonword = (
+            search_parameters.tdp_nonword if search_parameters.tdp_nonword is not None else (0.0, 0.0, "infinity", 0.0)
         )

         search_crp.acoustic_model_config = am.acoustic_model_config(
@@ -720,7 +717,7 @@ def recognize(
             tdp_scale=search_parameters.tdp_scale,
             tdp_transition=tdp_transition,
             tdp_silence=tdp_silence,
-            tdp_nonword=tdp_non_word,
+            tdp_nonword=tdp_nonword,
             nonword_phones=search_parameters.non_word_phonemes,
             tying_type="global-and-nonword",
         )
@@ -1152,6 +1149,205 @@ def push_delayed_tuple(
             left=best_left_prior,
             right=best_right_prior,
         )
+
+    def recognize_optimize_scales_v2(
+        self,
+        *,
+        label_info: LabelInfo,
+        num_encoder_output: int,
+        search_parameters: SearchParameters,
+        prior_scales: Union[
+            List[Tuple[float]],  # center
+            List[Tuple[float, float]],  # center, left
+            List[Tuple[float, float, float]],  # center, left, right
+            np.ndarray,
+        ],
+        tdp_scales: Union[List[float], np.ndarray],
+        tdp_sil: Optional[List[Tuple[TDP, TDP, TDP, TDP]]] = None,
+        tdp_nonword: Optional[List[Tuple[TDP, TDP, TDP, TDP]]] = None,
+        tdp_speech: Optional[List[Tuple[TDP, TDP, TDP, TDP]]] = None,
+        pron_scales: Optional[Union[List[float], np.ndarray]] = None,
+        altas_value=14.0,
+        altas_beam=14.0,
+        keep_value=10,
+        gpu: Optional[bool] = None,
+        cpu_rqmt: Optional[int] = None,
+        mem_rqmt: Optional[int] = None,
+        crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
+        pre_path: str = "scales",
+        cpu_slow: bool = True,
+    ) -> SearchParameters:
+        assert len(prior_scales) > 0
+        assert len(tdp_scales) > 0
+
+        recog_args = dataclasses.replace(search_parameters, altas=altas_value, beam=altas_beam)
+
+        if isinstance(prior_scales, np.ndarray):
+            prior_scales = [(s,) for s in prior_scales] if prior_scales.ndim == 1 else [tuple(s) for s in prior_scales]
+
+        prior_scales = [tuple(round(p, 2) for p in priors) for priors in prior_scales]
+        prior_scales = [
+            (p, 0.0, 0.0)
+            if isinstance(p, float)
+            else (p[0], 0.0, 0.0)
+            if len(p) == 1
+            else (p[0], p[1], 0.0)
+            if len(p) == 2
+            else p
+            for p in prior_scales
+        ]
+        tdp_scales = [round(s, 2) for s in tdp_scales]
+        tdp_sil = tdp_sil if tdp_sil is not None else [recog_args.tdp_silence]
+        tdp_nonword = tdp_nonword if tdp_nonword is not None else [recog_args.tdp_nonword]
+        tdp_speech = tdp_speech if tdp_speech is not None else [recog_args.tdp_speech]
+
+        use_pron = self.crp.lexicon_config.normalize_pronunciation and pron_scales is not None
+
+        if use_pron:
+            jobs = {
+                ((c, l, r), tdp, tdp_sl, tdp_nw, tdp_sp, pron): self.recognize_count_lm(
+                    add_sis_alias_and_output=False,
+                    calculate_stats=False,
+                    cpu_rqmt=cpu_rqmt,
+                    crp_update=crp_update,
+                    gpu=gpu,
+                    is_min_duration=False,
+                    keep_value=keep_value,
+                    label_info=label_info,
+                    mem_rqmt=mem_rqmt,
+                    name_override=f"{self.name}-pC{c}-pL{l}-pR{r}-tdp{tdp}-tdpSil{tdp_sl}-tdpNnw{tdp_nw}-tdpSp{tdp_sp}-pron{pron}",
+                    num_encoder_output=num_encoder_output,
+                    opt_lm_am=False,
+                    rerun_after_opt_lm=False,
+                    search_parameters=dataclasses.replace(
+                        recog_args,
+                        tdp_scale=tdp,
+                        tdp_silence=tdp_sl,
+                        tdp_nonword=tdp_nw,
+                        tdp_speech=tdp_sp,
+                        pron_scale=pron,
+                    ).with_prior_scale(left=l, center=c, right=r, diphone=c),
+                    remove_or_set_concurrency=False,
+                )
+                for ((c, l, r), tdp, tdp_sl, tdp_nw, tdp_sp, pron) in itertools.product(
+                    prior_scales, tdp_scales, tdp_sil, tdp_nonword, tdp_speech, pron_scales
+                )
+            }
+        else:
+            jobs = {
+                ((c, l, r), tdp, tdp_sl, tdp_nw, tdp_sp): self.recognize_count_lm(
+                    add_sis_alias_and_output=False,
+                    calculate_stats=False,
+                    cpu_rqmt=cpu_rqmt,
+                    crp_update=crp_update,
+                    gpu=gpu,
+                    is_min_duration=False,
+                    keep_value=keep_value,
+                    label_info=label_info,
+                    mem_rqmt=mem_rqmt,
+                    name_override=f"{self.name}-pC{c}-pL{l}-pR{r}-tdp{tdp}-tdpSil{tdp_sl}-tdpNnw{tdp_nw}-tdpSp{tdp_sp}-",
+                    num_encoder_output=num_encoder_output,
+                    opt_lm_am=False,
+                    rerun_after_opt_lm=False,
+                    search_parameters=dataclasses.replace(
+                        recog_args, tdp_scale=tdp, tdp_silence=tdp_sl, tdp_nonword=tdp_nw, tdp_speech=tdp_sp
+                    ).with_prior_scale(left=l, center=c, right=r, diphone=c),
+                    remove_or_set_concurrency=False,
+                )
+                for ((c, l, r), tdp, tdp_sl, tdp_nw, tdp_sp) in itertools.product(
+                    prior_scales, tdp_scales, tdp_sil, tdp_nonword, tdp_speech
+                )
+            }
+        jobs_num_e = {k: v.scorer.out_num_errors for k, v in jobs.items()}
+
+        if use_pron:
+            for ((c, l, r), tdp, tdp_sl, tdp_nw, tdp_sp, pron), recog_jobs in jobs.items():
+                if cpu_slow:
+                    recog_jobs.search.update_rqmt("run", {"cpu_slow": True})
+
+                pre_name = (
+                    f"{pre_path}/{self.name}/Lm{recog_args.lm_scale}-Pron{pron}-pC{c}-pL{l}-pR{r}-tdp{tdp}-"
+                    f"tdpSil{format_tdp(tdp_sl)}-tdpNw{format_tdp(tdp_nw)}-tdpSp{format_tdp(tdp_sp)}"
+                )
+
+                recog_jobs.lat2ctm.set_keep_value(keep_value)
+                recog_jobs.search.set_keep_value(keep_value)
+
+                recog_jobs.search.add_alias(pre_name)
+                tk.register_output(f"{pre_name}.wer", recog_jobs.scorer.out_report_dir)
+        else:
+            for ((c, l, r), tdp, tdp_sl, tdp_nw, tdp_sp), recog_jobs in jobs.items():
+                if cpu_slow:
+                    recog_jobs.search.update_rqmt("run", {"cpu_slow": True})
+
+                pre_name = (
+                    f"{pre_path}/{self.name}/Lm{recog_args.lm_scale}-Pron{recog_args.pron_scale}"
+                    f"-pC{c}-pL{l}-pR{r}-tdp{tdp}-tdpSil{format_tdp(tdp_sl)}-tdpNw{format_tdp(tdp_nw)}-tdpSp{format_tdp(tdp_sp)}"
+                )
+
+                recog_jobs.lat2ctm.set_keep_value(keep_value)
+                recog_jobs.search.set_keep_value(keep_value)
+
+                recog_jobs.search.add_alias(pre_name)
+                tk.register_output(f"{pre_name}.wer", recog_jobs.scorer.out_report_dir)
+
+        best_overall_wer = ComputeArgminJob({k: v.scorer.out_wer for k, v in jobs.items()})
+        best_overall_n = ComputeArgminJob(jobs_num_e)
+        tk.register_output(
+            f"decoding/scales-best/{self.name}/args",
+            best_overall_n.out_argmin,
+        )
+        tk.register_output(
+            f"decoding/scales-best/{self.name}/wer",
+            best_overall_wer.out_min,
+        )
+
+        def push_delayed_tuple(
+            argmin: DelayedBase,
+        ) -> Tuple[DelayedBase, DelayedBase, DelayedBase, DelayedBase]:
+            return tuple(argmin[i] for i in range(4))
+
+        # cannot destructure, need to use indices
+        # argmin key layout: ((center, left, right) priors, tdp_scale, tdp_silence, tdp_nonword, tdp_speech[, pron])
+        best_priors = best_overall_n.out_argmin[0]
+        best_tdp_scale = best_overall_n.out_argmin[1]
+        best_tdp_sil = best_overall_n.out_argmin[2]
+        best_tdp_nw = best_overall_n.out_argmin[3]
+        best_tdp_sp = best_overall_n.out_argmin[4]
+        if use_pron:
+            best_pron = best_overall_n.out_argmin[5]
+
+            base_cfg = dataclasses.replace(
+                search_parameters,
+                tdp_scale=best_tdp_scale,
+                tdp_silence=push_delayed_tuple(best_tdp_sil),
+                tdp_nonword=push_delayed_tuple(best_tdp_nw),
+                tdp_speech=push_delayed_tuple(best_tdp_sp),
+                pron_scale=best_pron,
+            )
+        else:
+            base_cfg = dataclasses.replace(
+                search_parameters,
+                tdp_scale=best_tdp_scale,
+                tdp_silence=push_delayed_tuple(best_tdp_sil),
+                tdp_nonword=push_delayed_tuple(best_tdp_nw),
+                tdp_speech=push_delayed_tuple(best_tdp_sp),
+            )
+
+        best_center_prior = best_priors[0]
+        if self.context_type.is_monophone():
+            return base_cfg.with_prior_scale(center=best_center_prior)
+        if self.context_type.is_joint_diphone():
+            return base_cfg.with_prior_scale(diphone=best_center_prior)
+
+        best_left_prior = best_priors[1]
+        if self.context_type.is_diphone():
+            return base_cfg.with_prior_scale(center=best_center_prior, left=best_left_prior)
+
+        best_right_prior = best_priors[2]
+        return base_cfg.with_prior_scale(
+            center=best_center_prior,
+            left=best_left_prior,
+            right=best_right_prior,
+        )
+

 
 class BASEFactoredHybridAligner(BASEFactoredHybridDecoder):
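
A sketch of how the new sweep could be driven from a config, assuming a decoder instance of this class plus label_info and base_search_params from the surrounding setup; the grid values and num_encoder_output are placeholders, not values from this commit:

    import itertools
    import numpy as np

    # hypothetical driver code; `decoder` is assumed to be a BASEFactoredHybridDecoder
    best_params = decoder.recognize_optimize_scales_v2(
        label_info=label_info,
        num_encoder_output=512,  # placeholder encoder dimension
        search_parameters=base_search_params,
        prior_scales=[(c, l) for c, l in itertools.product([0.2, 0.4, 0.6], [0.1, 0.2])],
        tdp_scales=np.arange(0.1, 0.7, 0.2),
        altas_value=12.0,
        altas_beam=16.0,
    )
    # best_params is a SearchParameters with the argmin TDP and prior scales filled in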
@@ -1267,9 +1463,9 @@ def get_alignment_job(
             if alignment_parameters.tdp_scale is not None
             else (0.0, 0.0, "infinity", 0.0)
         )
-        tdp_non_word = (
-            alignment_parameters.tdp_non_word
-            if alignment_parameters.tdp_non_word is not None
+        tdp_nonword = (
+            alignment_parameters.tdp_nonword
+            if alignment_parameters.tdp_nonword is not None
             else (0.0, 0.0, "infinity", 0.0)
         )

@@ -1282,7 +1478,7 @@ def get_alignment_job(
             tdp_scale=alignment_parameters.tdp_scale,
             tdp_transition=tdp_transition,
             tdp_silence=tdp_silence,
-            tdp_nonword=tdp_non_word,
+            tdp_nonword=tdp_nonword,
             nonword_phones=alignment_parameters.non_word_phonemes,
             tying_type="global-and-nonword",
         )
@@ -1345,7 +1541,7 @@ def get_alignment_job(
             if (
                 alignment_parameters.tdp_speech[-1]
                 + alignment_parameters.tdp_silence[-1]
-                + alignment_parameters.tdp_non_word[-1]
+                + alignment_parameters.tdp_nonword[-1]
                 > 0.0
             ):
                 import warnings
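
The rename from tdp_non_word to tdp_nonword has to be applied wherever the parameter dataclasses are replaced. A hedged sketch of that pattern with a stand-in dataclass (the field set and the (loop, forward, skip, exit) reading of the four-element tuples are assumptions, not taken from this commit; the (0.0, 0.0, "infinity", 0.0) default mirrors the fallback used above):

    import dataclasses
    from typing import Tuple, Union

    TDP = Union[float, str]  # "infinity" disables a transition

    @dataclasses.dataclass(frozen=True)
    class AlignmentParametersSketch:
        """Stand-in for the real alignment/search parameter class."""
        tdp_scale: float = 1.0
        tdp_speech: Tuple[TDP, TDP, TDP, TDP] = (0.0, 0.0, "infinity", 0.0)
        tdp_silence: Tuple[TDP, TDP, TDP, TDP] = (0.0, 0.0, "infinity", 0.0)
        tdp_nonword: Tuple[TDP, TDP, TDP, TDP] = (0.0, 0.0, "infinity", 0.0)  # was tdp_non_word

    # callers now use the new field name in dataclasses.replace, as in the hunks above
    params = dataclasses.replace(AlignmentParametersSketch(), tdp_nonword=(0.0, 3.0, "infinity", 20.0))
    assert params.tdp_nonword[2] == "infinity"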
