Skip to content

Commit 7299ebb

Browse files
committed
update
1 parent 286a2ab commit 7299ebb

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

users/zeineldeen/experiments/conformer_att_2022/librispeech_960/attention_asr_config.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -846,8 +846,12 @@ def create_config(
846846
else:
847847
raise ValueError("Invalid speed_pert_version")
848848

849-
if feature_extraction_net and global_stats:
850-
add_global_stats_norm(global_stats, exp_config["network"])
849+
if feature_extraction_net:
850+
if global_stats:
851+
add_global_stats_norm(global_stats, exp_config["network"])
852+
else:
853+
# use per-seq norm
854+
add_per_seq_norm(exp_config["network"])
851855

852856
if mixup_aug_opts:
853857
add_mixup_layers(
@@ -878,6 +882,8 @@ def create_config(
878882
net.update(feature_extraction_net)
879883
if global_stats:
880884
add_global_stats_norm(global_stats, net)
885+
else:
886+
add_per_seq_norm(net)
881887
if mixup_aug_opts and enable_mixup_in_pretrain:
882888
add_mixup_layers(net, feature_extraction_net, mixup_aug_opts, is_recog)
883889
net_as_str = "from returnn.config import get_global_config\n"
@@ -1011,6 +1017,11 @@ def create_config(
10111017
return returnn_config
10121018

10131019

1020+
def add_per_seq_norm(net):
1021+
net["log10_"] = copy.deepcopy(net["log10"])
1022+
net["log10"] = {"class": "norm", "from": "log10_", "axis": "T"}
1023+
1024+
10141025
def add_global_stats_norm(global_stats, net):
10151026
if isinstance(global_stats, dict):
10161027
from sisyphus.delayed_ops import DelayedFormat

0 commit comments

Comments
 (0)