Skip to content

Commit 8e7f015

Browse files
author
luca.gaudino
committed
masking fix experiments
1 parent b6fbbb7 commit 8e7f015

File tree

3 files changed

+129
-178
lines changed

3 files changed

+129
-178
lines changed

users/gaudino/experiments/conformer_att_2023/librispeech_960/attention_asr_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,8 @@ class CTCDecoderArgs(DecoderArgs):
516516
ctc_beam_search_tf: bool = False
517517
att_masking_fix: bool = False
518518
one_minus_term_mul_scale: float = 1.0
519+
one_minus_term_sub_scale: float = 0.0
520+
length_normalization: bool = False
519521

520522

521523
def create_config(

users/gaudino/experiments/conformer_att_2023/librispeech_960/configs/ctc_att_search.py

Lines changed: 93 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ def train_mini_self_att(
943943
)
944944

945945
# ctc + lm
946-
for beam_size in [55]:
946+
for beam_size in []:
947947
for lm_type in ["lstm"]: # "trafo" "lstm"
948948
for scale in [(0.5, 1)]:
949949
search_args = copy.deepcopy(oclr_args)
@@ -1002,7 +1002,7 @@ def train_mini_self_att(
10021002
)
10031003

10041004
# ctc + att
1005-
for beam_size in [12]:
1005+
for beam_size in []:
10061006
for scale in [(1, 0.1)]:
10071007
search_args = copy.deepcopy(oclr_args)
10081008
search_args["beam_size"] = beam_size
@@ -1086,7 +1086,7 @@ def train_mini_self_att(
10861086
# TODO: one-pass joint decoding with CTC
10871087

10881088
for comb_score_version in [2]:
1089-
for beam_size in [12]:
1089+
for beam_size in []:
10901090
for scale in [(0.3, 1.0)]:
10911091
att_scale, ctc_scale = scale
10921092
exp_name = f"joint_att_ctc_attScale{att_scale}_ctcScale{ctc_scale}_beam{beam_size}_combScoreV{comb_score_version}_fixRepeat"
@@ -1220,10 +1220,10 @@ def train_mini_self_att(
12201220
use_sclite=True,
12211221
)
12221222
if mode == "att":
1223-
for prior_scale in [0.15, 0.2, 0.25]:
1223+
for prior_scale in []:
12241224
att_scale, ctc_scale = (0.65, 0.35)
12251225
exp_name = (
1226-
f"ctc_decoder_attScale{att_scale}_ctcScale{ctc_scale}_beam_12_priorScale_{prior_scale}_maskfix"
1226+
f"ctc_decoder_attScale{att_scale}_ctcScale{ctc_scale}_beam_32_priorScale_{prior_scale}_maskfix"
12271227
)
12281228
search_args = copy.deepcopy(prior_corr_args)
12291229
search_args["beam_size"] = 32
@@ -1359,7 +1359,7 @@ def train_mini_self_att(
13591359
)
13601360

13611361
# test remove_eos
1362-
for mode in ["greedy", "att", "lstm_lm"]: # ["greedy", "att", "lstm_lm"]
1362+
for mode in []: # ["greedy", "att", "lstm_lm"]
13631363
if mode == "greedy":
13641364
search_args = copy.deepcopy(oclr_args)
13651365
search_args["decoder_args"] = CTCDecoderArgs(remove_eos=True, add_eos_to_blank=True)
@@ -1450,136 +1450,8 @@ def train_mini_self_att(
14501450
use_sclite=True,
14511451
)
14521452

1453-
# test blank scale + repeat prob scale
1454-
for mode in ["att", "lstm_lm"]:
1455-
search_args = copy.deepcopy(oclr_args)
1456-
if mode == "att":
1457-
for blank_scale in [1.0]:
1458-
for repeat_scale in [-0.5]:
1459-
att_scale, ctc_scale = (0.3, 1.0)
1460-
exp_name = f"ctc_decoder_attScale{att_scale}_ctcScale{ctc_scale}_beam_12_blankScale_{blank_scale}_repeatScale_{repeat_scale}"
1461-
search_args = copy.deepcopy(oclr_args)
1462-
search_args["beam_size"] = 12
1463-
search_args["decoder_args"] = CTCDecoderArgs(
1464-
add_att_dec=True,
1465-
att_scale=att_scale,
1466-
ctc_scale=ctc_scale,
1467-
blank_prob_scale=blank_scale,
1468-
repeat_prob_scale=repeat_scale,
1469-
)
1470-
run_decoding(
1471-
exp_name=exp_name,
1472-
train_data=train_data,
1473-
checkpoint=train_job_avg_ckpt[
1474-
f"base_conf_12l_lstm_1l_conv6_OCLR_sqrdReLU_cyc915_ep2035_peak0.0009_retrain1_const20_linDecay580_{1e-4}"
1475-
],
1476-
search_args=search_args,
1477-
feature_extraction_net=log10_net_10ms,
1478-
bpe_size=BPE_10K,
1479-
test_sets=["dev-other"],
1480-
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1481-
use_sclite=True,
1482-
time_rqmt=1.0 if beam_size <= 128 else 1.5,
1483-
)
1484-
if mode == "lstm_lm":
1485-
for lm_scale in [0.5]:
1486-
for repeat_scale in [-1, -0.5]:
1487-
ctc_scale = 1.0
1488-
lm_type = "lstm"
1489-
ext_lm_opts = lstm_lm_opts_map[BPE_10K]
1490-
time_rqmt = 1.0
1491-
beam_size = 55
1492-
blank_scale = 1.0
1493-
1494-
search_args["decoder_args"] = CTCDecoderArgs(
1495-
add_ext_lm=True,
1496-
lm_type=lm_type,
1497-
ext_lm_opts=ext_lm_opts,
1498-
lm_scale=lm_scale,
1499-
ctc_scale=ctc_scale,
1500-
blank_prob_scale=blank_scale,
1501-
repeat_prob_scale=repeat_scale,
1502-
)
1503-
search_args["beam_size"] = beam_size
1504-
run_decoding(
1505-
exp_name=f"ctc_{ctc_scale}_{lm_type}_{lm_scale}_beam_{beam_size}_blankScale_{blank_scale}_repeatScale_{repeat_scale}",
1506-
train_data=train_data,
1507-
checkpoint=train_job_avg_ckpt[
1508-
f"base_conf_12l_lstm_1l_conv6_OCLR_sqrdReLU_cyc915_ep2035_peak0.0009_retrain1_const20_linDecay580_{1e-4}"
1509-
],
1510-
search_args=search_args,
1511-
feature_extraction_net=log10_net_10ms,
1512-
bpe_size=BPE_10K,
1513-
test_sets=["dev-other"],
1514-
time_rqmt=time_rqmt,
1515-
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1516-
use_sclite=True,
1517-
)
1518-
1519-
# test ts_reward
1520-
for mode in ["att", "lstm_lm"]:
1521-
search_args = copy.deepcopy(oclr_args)
1522-
if mode == "att":
1523-
for ts_reward in [1.0, 1.3, 1.5, 2.0]:
1524-
att_scale, ctc_scale = (0.3, 1.0)
1525-
exp_name = f"ctc_decoder_attScale{att_scale}_ctcScale{ctc_scale}_beam_12_tsReward{ts_reward}"
1526-
search_args = copy.deepcopy(oclr_args)
1527-
search_args["beam_size"] = 12
1528-
search_args["decoder_args"] = CTCDecoderArgs(
1529-
add_att_dec=True,
1530-
att_scale=att_scale,
1531-
ctc_scale=ctc_scale,
1532-
ts_reward=ts_reward,
1533-
)
1534-
run_decoding(
1535-
exp_name=exp_name,
1536-
train_data=train_data,
1537-
checkpoint=train_job_avg_ckpt[
1538-
f"base_conf_12l_lstm_1l_conv6_OCLR_sqrdReLU_cyc915_ep2035_peak0.0009_retrain1_const20_linDecay580_{1e-4}"
1539-
],
1540-
search_args=search_args,
1541-
feature_extraction_net=log10_net_10ms,
1542-
bpe_size=BPE_10K,
1543-
test_sets=["dev-other"],
1544-
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1545-
use_sclite=True,
1546-
time_rqmt=1.0 if beam_size <= 128 else 1.5,
1547-
)
1548-
if mode == "lstm_lm":
1549-
for ts_reward in [1.0, 1.5, 2.0]:
1550-
ctc_scale = 1.0
1551-
lm_type = "lstm"
1552-
ext_lm_opts = lstm_lm_opts_map[BPE_10K]
1553-
time_rqmt = 1.0
1554-
beam_size = 55
1555-
blank_scale = 1.0
1556-
1557-
search_args["decoder_args"] = CTCDecoderArgs(
1558-
add_ext_lm=True,
1559-
lm_type=lm_type,
1560-
ext_lm_opts=ext_lm_opts,
1561-
lm_scale=lm_scale,
1562-
ctc_scale=ctc_scale,
1563-
ts_reward=ts_reward,
1564-
)
1565-
search_args["beam_size"] = beam_size
1566-
run_decoding(
1567-
exp_name=f"ctc_{ctc_scale}_{lm_type}_{lm_scale}_beam_{beam_size}_tsReward_{ts_reward}",
1568-
train_data=train_data,
1569-
checkpoint=train_job_avg_ckpt[
1570-
f"base_conf_12l_lstm_1l_conv6_OCLR_sqrdReLU_cyc915_ep2035_peak0.0009_retrain1_const20_linDecay580_{1e-4}"
1571-
],
1572-
search_args=search_args,
1573-
feature_extraction_net=log10_net_10ms,
1574-
bpe_size=BPE_10K,
1575-
test_sets=["dev-other"],
1576-
time_rqmt=time_rqmt,
1577-
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1578-
use_sclite=True,
1579-
)
1580-
15811453
# ctc + att masking fix sanity check
1582-
for beam_size in [12, 32, 64]:
1454+
for beam_size in [32]:
15831455
for scale in [(0.65, 0.35)]:
15841456
search_args = copy.deepcopy(oclr_args)
15851457
search_args["beam_size"] = beam_size
@@ -1598,26 +1470,66 @@ def train_mini_self_att(
15981470
feature_extraction_net=log10_net_10ms,
15991471
bpe_size=BPE_10K,
16001472
test_sets=["dev-other"],
1473+
# test_sets=["dev-clean", "dev-other", "test-clean", "test-other"],
16011474
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1602-
use_sclite=False,
1475+
use_sclite=True,
16031476
)
16041477

1605-
# ctc att mask fix + scales
1606-
for beam_size in [32]:
1607-
for omt_mul in [0.0, 0.5]:
1478+
# ctc + att masking fix large beam
1479+
for beam_size in [256, 512]:
1480+
for scale in [(0.65, 0.35), (0.67, 0.33), (0.63, 0.37)]:
16081481
search_args = copy.deepcopy(oclr_args)
16091482
search_args["beam_size"] = beam_size
1610-
att_scale, ctc_scale = (0.65, 0.35)
1483+
search_args["batch_size"] = 4000 * 160
1484+
att_scale, ctc_scale = scale
16111485

1486+
search_args["decoder_args"] = CTCDecoderArgs(
1487+
add_att_dec=True, att_scale=att_scale, ctc_scale=ctc_scale, att_masking_fix=True
1488+
)
1489+
run_decoding(
1490+
exp_name=f"ctc_{ctc_scale}_att_{att_scale}_beam{beam_size}_masking_fix",
1491+
train_data=train_data,
1492+
checkpoint=train_job_avg_ckpt[
1493+
f"base_conf_12l_lstm_1l_conv6_OCLR_sqrdReLU_cyc915_ep2035_peak0.0009_retrain1_const20_linDecay580_{1e-4}"
1494+
],
1495+
search_args=search_args,
1496+
feature_extraction_net=log10_net_10ms,
1497+
bpe_size=BPE_10K,
1498+
test_sets=["dev-other"],
1499+
# test_sets=["dev-clean", "dev-other", "test-clean", "test-other"],
1500+
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1501+
use_sclite=True,
1502+
time_rqmt=3.0,
1503+
)
1504+
1505+
# ctc att mask fix + lm
1506+
for beam_size in [32]:
1507+
prior_corr_args = copy.deepcopy(oclr_args)
1508+
prior_corr_args[
1509+
"ctc_log_prior_file"
1510+
] = "/work/asr3/zeineldeen/hiwis/luca.gaudino/setups-data/2023-02-22--conformer-swb/work/i6_core/returnn/extract_prior/ReturnnComputePriorJobV2.ZdcvhAOyWl95/output/prior.txt"
1511+
# ] = "/u/luca.gaudino/debug/ctc/prior.txt"
1512+
for scale in [(0.65, 0.35, 0.33)]:
1513+
search_args = copy.deepcopy(oclr_args)
1514+
search_args["beam_size"] = beam_size
1515+
att_scale, ctc_scale, lm_scale = scale
1516+
# prior_scale = 0.3
1517+
lm_type = "lstm"
1518+
ext_lm_opts = lstm_lm_opts_map[BPE_10K]
16121519
search_args["decoder_args"] = CTCDecoderArgs(
16131520
add_att_dec=True,
16141521
att_scale=att_scale,
16151522
ctc_scale=ctc_scale,
16161523
att_masking_fix=True,
1617-
one_minus_term_mul_scale=omt_mul,
1524+
# ctc_prior_correction=True,
1525+
# prior_scale=prior_scale,
1526+
add_ext_lm=True,
1527+
lm_type=lm_type,
1528+
ext_lm_opts=ext_lm_opts,
1529+
lm_scale=lm_scale,
16181530
)
16191531
run_decoding(
1620-
exp_name=f"ctc_{ctc_scale}_att_{att_scale}_beam{beam_size}_masking_fix_omt{omt_mul}",
1532+
exp_name=f"ctc_{ctc_scale}_att_{att_scale}_lm_{lm_scale}_beam{beam_size}_masking_fix",
16211533
train_data=train_data,
16221534
checkpoint=train_job_avg_ckpt[
16231535
f"base_conf_12l_lstm_1l_conv6_OCLR_sqrdReLU_cyc915_ep2035_peak0.0009_retrain1_const20_linDecay580_{1e-4}"
@@ -1626,23 +1538,51 @@ def train_mini_self_att(
16261538
feature_extraction_net=log10_net_10ms,
16271539
bpe_size=BPE_10K,
16281540
test_sets=["dev-other"],
1541+
# test_sets=["dev-clean", "dev-other", "test-clean", "test-other"],
16291542
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1630-
use_sclite=False,
1543+
use_sclite=True,
1544+
time_rqmt=1.0,
16311545
)
1632-
for blank_scale in [1.0]:
1546+
1547+
# ctc + att masking fix scales
1548+
for beam_size in [32]:
1549+
for scale in [(0.65, 0.35)]:
16331550
search_args = copy.deepcopy(oclr_args)
16341551
search_args["beam_size"] = beam_size
1635-
att_scale, ctc_scale = (0.65, 0.35)
1552+
att_scale, ctc_scale = scale
16361553

16371554
search_args["decoder_args"] = CTCDecoderArgs(
1638-
add_att_dec=True,
1639-
att_scale=att_scale,
1640-
ctc_scale=ctc_scale,
1641-
att_masking_fix=True,
1642-
blank_prob_scale=blank_scale,
1555+
add_att_dec=True, att_scale=att_scale, ctc_scale=ctc_scale, att_masking_fix=True,
1556+
one_minus_term_mul_scale=1.5,
1557+
)
1558+
run_decoding(
1559+
exp_name=f"ctc_{ctc_scale}_att_{att_scale}_beam{beam_size}_omt{1.5}",
1560+
train_data=train_data,
1561+
checkpoint=train_job_avg_ckpt[
1562+
f"base_conf_12l_lstm_1l_conv6_OCLR_sqrdReLU_cyc915_ep2035_peak0.0009_retrain1_const20_linDecay580_{1e-4}"
1563+
],
1564+
search_args=search_args,
1565+
feature_extraction_net=log10_net_10ms,
1566+
bpe_size=BPE_10K,
1567+
test_sets=["dev-other"],
1568+
# test_sets=["dev-clean", "dev-other", "test-clean", "test-other"],
1569+
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1570+
use_sclite=True,
1571+
)
1572+
1573+
# ctc + att length norm
1574+
for beam_size in [32]:
1575+
for scale in [(0.65, 0.35)]:
1576+
search_args = copy.deepcopy(oclr_args)
1577+
search_args["beam_size"] = beam_size
1578+
att_scale, ctc_scale = scale
1579+
1580+
search_args["decoder_args"] = CTCDecoderArgs(
1581+
add_att_dec=True, att_scale=att_scale, ctc_scale=ctc_scale, att_masking_fix=True,
1582+
length_normalization=True,
16431583
)
16441584
run_decoding(
1645-
exp_name=f"ctc_{ctc_scale}_att_{att_scale}_beam{beam_size}_masking_fix_blank_scale{blank_scale}",
1585+
exp_name=f"ctc_{ctc_scale}_att_{att_scale}_beam{beam_size}_mf_len_norm",
16461586
train_data=train_data,
16471587
checkpoint=train_job_avg_ckpt[
16481588
f"base_conf_12l_lstm_1l_conv6_OCLR_sqrdReLU_cyc915_ep2035_peak0.0009_retrain1_const20_linDecay580_{1e-4}"
@@ -1651,6 +1591,7 @@ def train_mini_self_att(
16511591
feature_extraction_net=log10_net_10ms,
16521592
bpe_size=BPE_10K,
16531593
test_sets=["dev-other"],
1594+
# test_sets=["dev-clean", "dev-other", "test-clean", "test-other"],
16541595
remove_label={"<s>", "<blank>"}, # blanks are removed in the network
1655-
use_sclite=False,
1596+
use_sclite=True,
16561597
)

0 commit comments

Comments
 (0)