Skip to content

Commit 918e20e

Browse files
committed
before starting with Ted
1 parent 7299ebb commit 918e20e

File tree

6 files changed

+29
-9
lines changed

6 files changed

+29
-9
lines changed

users/raissi/experiments/librispeech/configs/LFR_factored/baseline/alignment/config_alignment_lfr.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ def get_system(key, lr=4e-4, num_epochs=None, am_scale=1.0, tdp_scale=0.1):
7272
label_info_init_args = {
7373
"ph_emb_size": 0,
7474
"st_emb_size": 0,
75-
"state_tying": RasrStateTying.monophone,
75+
"state_tying": 'monophone-dense',#RasrStateTying.monophone,
7676
"n_states_per_phone": 1
7777
}
7878
init_args_system = {
@@ -174,6 +174,7 @@ def get_system(key, lr=4e-4, num_epochs=None, am_scale=1.0, tdp_scale=0.1):
174174
returnn_config=s.experiments[key]["returnn_config"],
175175
log_linear_scales=log_linear_scales
176176
)
177+
177178
s.experiments[key]["returnn_config"] = bw_augmented_returnn_config
178179

179180
s.returnn_rasr_training_fullsum(
@@ -183,6 +184,8 @@ def get_system(key, lr=4e-4, num_epochs=None, am_scale=1.0, tdp_scale=0.1):
183184
nn_train_args=train_args,
184185
)
185186

187+
s.label_info = dataclasses.replace(s.label_info, state_tying=RasrStateTying.monophone)
188+
186189
return_config_dict_infer = s.get_config_with_legacy_prolog_and_epilog(
187190
config=s.experiments[key]["returnn_config"].config,
188191
epilog_additional_str=train_helpers.specaugment.get_legacy_specaugment_epilog_blstm(
@@ -194,10 +197,10 @@ def get_system(key, lr=4e-4, num_epochs=None, am_scale=1.0, tdp_scale=0.1):
194197

195198
s.set_single_prior_returnn_rasr(
196199
key=key,
197-
epoch=450,
200+
epoch=400,
198201
train_corpus_key=s.crp_names["train"],
199202
dev_corpus_key=s.crp_names["cvtrain"],
200-
data_share=0.3,
203+
data_share=0.1,
201204
context_type=PhoneticContext.monophone,
202205
smoothen=True,
203206
output_layer_name="center-output"

users/raissi/setups/common/TF_factored_hybrid_system.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,7 @@ def get_conformer_network(
211211
network = net_helpers.augment.augment_net_with_label_pops(
212212
network, label_info=self.label_info, frame_rate_reduction_ratio_info=frame_rate_reduction_ratio_info
213213
)
214-
if frame_rate_reduction_ratio_info.factor > 1:
214+
if frame_rate_reduction_ratio_info.factor > 1 and frame_rate_reduction_ratio_info.single_state_alignment:
215215
network["slice_classes"] = {
216216
"class": "slice",
217217
"from": network["classes_"]["from"],

users/raissi/setups/common/encoder/conformer/best_setup.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33

44
from enum import Enum
5-
from typing import Union, Optional
5+
from typing import Union, Optional, Tuple
66
from i6_experiments.users.raissi.setups.common.encoder.conformer.get_network_args import (
77
get_encoder_args,
88
get_network_args,
@@ -29,7 +29,7 @@ def get_best_model_config(
2929
num_classes: int,
3030
num_input_feature: int,
3131
*,
32-
chunking: Optional[str] = None,
32+
chunking: [str, Tuple] = None,
3333
int_loss_at_layer: Optional[int] = None,
3434
int_loss_scale: Optional[float] = None,
3535
label_smoothing: Optional[float] = None,
@@ -52,7 +52,10 @@ def get_best_model_config(
5252

5353
assert model_dim % att_dim == 0, "model_dim must be divisible by number of att heads"
5454

55-
clipping, overlap = [int(v) for v in chunking.split(":")] if chunking is not None else (400, 200)
55+
if isinstance(chunking, tuple):
56+
[clipping, overlap] = [ele['data'] for ele in chunking]
57+
else:
58+
clipping, overlap = [int(v) for v in chunking.split(":")] if chunking is not None else (400, 200)
5659

5760
enc_args = get_encoder_args(
5861
model_dim // att_dim,

users/raissi/setups/common/helpers/train/chunking.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@ def chunking_with_nfactor(
55
chunk_str: str, factor: int, data_key: str = "data", class_key: str = "classes"
66
) -> Tuple[Dict[str, int], Dict[str, int]]:
77
"""
8-
It gives back the cunking dictionary for different factors. Factor 1 means no subsampling is done
8+
It gives back the chunking dictionary for different factors. Factor 1 means no subsampling is done
99
"""
10+
1011
chunk, overlap = [int(p.strip()) for p in chunk_str.strip().split(":")]
11-
return ({"classes": chunk // factor, "data": chunk}, {"classes": overlap // factor, "data": overlap})
12+
return ({class_key: chunk // factor, data_key: chunk}, {class_key: overlap // factor, data_key: overlap})

users/raissi/setups/common/helpers/train/network_params.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,3 +40,7 @@ def __post_init__(self):
4040
# no chunking for full-sum
4141
default_blstm_fullsum = GeneralNetworkParams(l2=1e-4, use_multi_task=False, add_mlps=False)
4242
default_conformer_viterbi = GeneralNetworkParams(chunking="400:200", l2=1e-6, specaug_args=asdict(default_sa_args))
43+
44+
frameshift40_conformer_viterbi = GeneralNetworkParams(
45+
l2=1e-6, chunking="400:200", specaug_args=asdict(default_sa_args), frame_rate_reduction_ratio_factor=4
46+
)

users/raissi/utils/default_tools.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@ def get_rasr_binary_path(rasr_path):
7878

7979
# common
8080
RETURNN_ROOT = tk.Path("/work/tools/users/raissi/returnn_versions/conformer", hash_overwrite="CONFORMER_RETURNN_ROOT")
81+
RETURNN_ROOT_MORITZ = tk.Path("/work/asr3/raissi/shared_workspaces/gunz/2023-05--thesis-baselines-tf2/i6_core/tools/git/CloneGitRepositoryJob.0TxYoqLkxbuC/output/returnn", hash_overwrite="CONFORMER_RETURNN_Len_FIX")
8182
RETURNN_ROOT_TORCH = tk.Path("/work/tools/users/raissi/returnn_versions/torch", hash_overwrite="TORCH_RETURNN_ROOT")
8283

8384
SCTK_BINARY_PATH = compile_sctk(branch="v2.4.12") # use last published version
@@ -104,6 +105,12 @@ def __post_init__(self) -> None:
104105
rasr_binary_path=U16_RASR_BINARY_PATHS["TF2"],
105106
)
106107

108+
u16_default_tools_returnn_fix = ToolPaths(
109+
returnn_root=RETURNN_ROOT_MORITZ,
110+
returnn_python_exe=U16_RETURNN_LAUNCHERS["TF2"],
111+
rasr_binary_path=U16_RASR_BINARY_PATHS["TED_COMMON"],
112+
)
113+
107114

108115
u16_default_tools_ted = ToolPaths(
109116
returnn_root=RETURNN_ROOT,
@@ -112,6 +119,8 @@ def __post_init__(self) -> None:
112119
)
113120

114121

122+
123+
115124
u22_tools_tf = ToolPaths(
116125
returnn_root=RETURNN_ROOT_TORCH,
117126
returnn_python_exe=U22_RETURNN_LAUNCHERS["TF2"],

0 commit comments

Comments
 (0)