
Commit b2d2baf

Authored by Marvin84, Judyxujj, Jingjing Xu, mmz33, and Simon Berger
Dummy PR (#226)
* Add swb PyTorch ctc setup (#219) * add initial setup * rm binary files * rm binary files --------- Co-authored-by: Jingjing Xu <[email protected]>
* add conformer enc with more weight dropout
* fix
* add more weight noise opts
* black formatting
* Jing tedlium independent softmax (#220) * tedlium ctc pytorch * rm empty files --------- Co-authored-by: Jingjing Xu <[email protected]>
* use ff regs for mhsa out
* add more regularized trafo dec
* update
* add regs to rnn decoder
* more
* add more regs to rnn dec
* black formatting
* Update users/berger
* update
* add readme for RF
* more
* cleanup, generalize, different spm vocab sizes
* more
* more
* small fix
* update ls att+ctc+lm
* add args
* fix pretraining
* Glow-TTS-ASR: Update with fixed invertibility tests
* Glow-TTS-ASR update
* fix
* Glow-TTS-ASR: Cleanup and comments/documentation
* small fixes
* spm20k
* more ctc and rnn-t librispeech experiments
* add greedy decoder
* black
* add ebranchformer
* more
* more
* better layer names for ebranchformer
* better
* better
* decouple mhsa residual
* cleanup
* refactor args. add ebranch config
* more
* Update users/berger
* Update users/berger
* better
* config enable write cache manager
* standalone 2024 setup add LSTM lm pipeline
* update
* add horovod to libri pipeline
* update configs
* fix
* more
* update
* update zoneout fix ted2
* Update users/berger
* Update users/berger
* more
* update
* update
* Update users/berger
* more
* more
* update
* update
* ConformerV2 setup
* updates
* cleanup
* updates and fix mel norm + zoneout
* update conf v2
* more
* update
* Update users/berger
* fixes and update
* add CTC gauss weights
* convert ls960 LSTM LM to rf
* fix
* fix
* fix
* more
* more
* use_eos_postfix
* fix CTC with EOS recog scoring
* ctc eos fix more
* update
* Update Glow-TTS-ASR
* updates quant
* added factored bw
* deleted wrong stashed
* update trainings and initial rnnt decoder rf
* feature batch norm
* feature normalization
* recog fix API doc
* collect stats, initial code
* librispeech feature stats
* feature global norm
* small fixes
* small fix
* small fix
* small fix
* fix feat norm search
* update
* update
* update
* add more weight drop to rnn decoder
* add chunked rnn decoder
* update
* fix
* fix
* more
* more
* fix name
* more
* more
* small fix
* more
* more
* cleanup
* fix
* comment
* more
* cleanup
* more
* add canary 1b recog sis prepare config
* add config (#223) Co-authored-by: Jingjing Xu <[email protected]>
* more
* add nemo model download job
* add nemo search job
* add custom hash
* fix
* add nemo search
* first version of nemo search
* better
* fix bug
* better
* add missing search output path
* add compute_wer func
* add wer as output var
* run search for all test sets with canary 1b model
* add configs (#224) Co-authored-by: Jingjing Xu <[email protected]>
* update
* update
* register wer as out
* update
* add libri test other test set
* fix args
* fix args
* update
* add modified normalized
* Create README.md
* Update README.md
* Update users/berger
* Update users/berger
* more
* more
* more
* more
* more
* more
* prepare for some more modeling code
* move SequentialLayerDrop
* better
* move mixup
* rnnt dec rf WIP
* Update users/berger
* update users/raissi monofactored
* update
* update
* ls960 pretrain: use phoneme info for mask boundaries
* BatchRenorm initial implementation (untested)
* test_piecewise_linear
* test_piecewise_linear use dyn_lr_piecewise_linear
* dyn_lr_piecewise_linear use RETURNN PiecewiseLinear
* DeleteLemmataFromLexiconJob (#225)
* ls960 pretrain: phoneme mask and other targets
* ls960 pretrain: update num epochs
* better
* first version of beam search
* fix
* fix enc shape
* use expand instead of repeat for efficiency
* better
* add hyp postprocessing
* better
* add beam search
* remove print
* more
* BatchRenorm with build_from_dict
* more
* small fix
* more
* small fix
* reorder code
* comment
* prior
* cleanup
* cache enc beam expansion
* fix bug
* update
* more
* more
* more
* more
* LS spm vocab alias
* make private
* move
* lazy, aliases
* update and test rf vs torch mhsa
* fix warning
* fix bug
* vocab outputs
* more
* more, AED featBN, sampling
* extract SPM vocab
* add rtfs
* add cache suffix
* update
* fix
* add debug out
* add batch size logging
* import i6_models conformer in rf, batch 1
* SamplingBytePairEncoding for SentencePiece
* add gradient clipping to example baseline
* 2-precision WER and quantization helper
* HDF alignment labels example data pipeline
* ls960 pretrain: fix python launcher for itc/i6
* latest users/raissi

---------

Co-authored-by: Judyxujj <[email protected]>
Co-authored-by: Jingjing Xu <[email protected]>
Co-authored-by: Mohammad Zeineldeen <[email protected]>
Co-authored-by: Simon Berger <[email protected]>
Co-authored-by: schmitt <[email protected]>
Co-authored-by: Albert Zeyer <[email protected]>
Co-authored-by: luca.gaudino <[email protected]>
Co-authored-by: Lukas Rilling <[email protected]>
Co-authored-by: Nick Rossenbach <[email protected]>
Co-authored-by: Benedikt Hilmes <[email protected]>
Co-authored-by: Mohammad Zeineldeen <[email protected]>
Co-authored-by: Peter Vieting <[email protected]>
Co-authored-by: vieting <[email protected]>
1 parent cf5884a commit b2d2baf

528 files changed, +54982 −17037 lines changed


example_setups/librispeech/ctc_rnnt_standalone_2024/experiments/ctc_bpe/baseline.py

Lines changed: 23 additions & 0 deletions
@@ -61,6 +61,7 @@ def bpe_ls960_1023_base():
     }
 
     from ...pytorch_networks.ctc.decoder.flashlight_ctc_v1 import DecoderConfig
+    from ...pytorch_networks.ctc.decoder.greedy_bpe_ctc_v3 import DecoderConfig as GreedyDecoderConfig
 
     def tune_and_evaluate_helper(
         training_name: str,
@@ -121,6 +122,22 @@ def tune_and_evaluate_helper(
             **default_returnn,
         )
 
+    def greedy_search_helper(training_name: str, asr_model: ASRModel, decoder_config: GreedyDecoderConfig):
+        # remove prior if exists
+        asr_model = copy.deepcopy(asr_model)
+        asr_model.prior_file = None
+
+        search_name = training_name + "/search_greedy"
+        search_jobs, wers = search(
+            search_name,
+            forward_config={},
+            asr_model=asr_model,
+            decoder_module="ctc.decoder.greedy_bpe_ctc_v3",
+            decoder_args={"config": asdict(decoder_config)},
+            test_dataset_tuples=dev_dataset_tuples,
+            **default_returnn,
+        )
+
     default_decoder_config_bpe5000 = DecoderConfig(
         lexicon=get_text_lexicon(prefix=prefix_name, librispeech_key="train-other-960", bpe_size=5000),
         returnn_vocab=label_datastream_bpe5000.vocab,
@@ -200,6 +217,7 @@ def tune_and_evaluate_helper(
         "max_seq_length": {"audio_features": 35 * 16000},
         "accum_grad_multiple_step": 1,
         "torch_amp_options": {"dtype": "bfloat16"},
+        "gradient_clip": 1.0,
     }
 
     network_module = "ctc.conformer_1023.i6modelsV1_VGG4LayerActFrontendV1_v6"
@@ -224,3 +242,8 @@ def tune_and_evaluate_helper(
         lm_scales=[1.6, 1.8, 2.0],
         prior_scales=[0.2, 0.3, 0.4],
     )
+
+    greedy_decoder_config = GreedyDecoderConfig(
+        returnn_vocab=label_datastream_bpe5000.vocab,
+    )
+    greedy_search_helper(training_name=training_name, asr_model=asr_model, decoder_config=greedy_decoder_config)

example_setups/librispeech/ctc_rnnt_standalone_2024/experiments/ctc_phon/baseline.py

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ def tune_and_evaluate_helper(
         "max_seq_length": {"audio_features": 35 * 16000},
         "accum_grad_multiple_step": 1,
         "torch_amp_options": {"dtype": "bfloat16"},
+        "gradient_clip": 1.0,
     }
 
     network_module = "ctc.conformer_1023.i6modelsV1_VGG4LayerActFrontendV1_v6"
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+"""
+Greedy CTC decoder without any extras
+
+v3: add config objects
+"""
+from dataclasses import dataclass
+import time
+import torch
+
+
+@dataclass
+class DecoderConfig:
+    returnn_vocab: str
+
+
+@dataclass
+class ExtraConfig:
+    # used for RTF logging
+    print_rtf: bool = True
+    sample_rate: int = 16000
+
+    # Hypothesis logging
+    print_hypothesis: bool = True
+
+
+def forward_init_hook(run_ctx, **kwargs):
+    # we are storing durations, but call it output.hdf to match
+    # the default output of the ReturnnForwardJob
+    config = DecoderConfig(**kwargs["config"])
+    extra_config_dict = kwargs.get("extra_config", {})
+    extra_config = ExtraConfig(**extra_config_dict)
+
+    run_ctx.recognition_file = open("search_out.py", "wt")
+    run_ctx.recognition_file.write("{\n")
+
+    from returnn.datasets.util.vocabulary import Vocabulary
+
+    vocab = Vocabulary.create_vocab(vocab_file=config.returnn_vocab, unknown_label=None)
+    run_ctx.labels = vocab.labels
+
+    run_ctx.print_rtf = extra_config.print_rtf
+    if run_ctx.print_rtf:
+        run_ctx.running_audio_len_s = 0
+        run_ctx.total_time = 0
+
+    run_ctx.print_hypothesis = extra_config.print_hypothesis
+
+
+def forward_finish_hook(run_ctx, **kwargs):
+    run_ctx.recognition_file.write("}\n")
+    run_ctx.recognition_file.close()
+
+    print("Total-time: %.2f, Batch-RTF: %.3f" % (run_ctx.total_time, run_ctx.total_time / run_ctx.running_audio_len_s))
+
+
+def forward_step(*, model, data, run_ctx, **kwargs):
+    raw_audio = data["raw_audio"]  # [B, T', F]
+    raw_audio_len = data["raw_audio:size1"]  # [B]
+
+    audio_len_batch = torch.sum(raw_audio_len).detach().cpu().numpy() / 16000
+
+    if run_ctx.print_rtf:
+        run_ctx.running_audio_len_s += audio_len_batch
+        am_start = time.time()
+
+    logprobs, audio_features_len = model(
+        raw_audio=raw_audio,
+        raw_audio_len=raw_audio_len,
+    )
+    batch_indices = []
+    for lp, l in zip(logprobs, audio_features_len):
+        batch_indices.append(torch.unique_consecutive(torch.argmax(lp[:l], dim=-1), dim=0).detach().cpu().numpy())
+
+    if run_ctx.print_rtf:
+        am_time = time.time() - am_start
+        run_ctx.total_time += am_time
+        print("Batch-time: %.2f, Batch-RTF: %.3f" % (am_time, am_time / audio_len_batch))
+
+    tags = data["seq_tag"]
+
+    for indices, tag in zip(batch_indices, tags):
+        sequence = [run_ctx.labels[idx] for idx in indices if idx < len(run_ctx.labels)]
+        sequence = [s for s in sequence if (not s.startswith("<") and not s.startswith("["))]
+        text = " ".join(sequence).replace("@@ ", "")
+        if run_ctx.print_hypothesis:
+            print(text)
+        run_ctx.recognition_file.write("%s: %s,\n" % (repr(tag), repr(text)))

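Note: as an illustration of the greedy decoding performed in forward_step above (per-frame argmax, merging of repeated frames with torch.unique_consecutive, dropping blank/special labels, undoing BPE merges), here is a minimal self-contained sketch. The toy vocabulary and frame decisions are made up for this example and are not part of the setup.

# Toy illustration; hypothetical vocabulary and frame sequence.
import torch

labels = ["<blank>", "HE@@", "LLO", "WORLD"]     # made-up BPE-style vocabulary, index 0 = CTC blank
frame_argmax = torch.tensor([0, 1, 1, 2, 0, 3])  # pretend per-frame argmax over the logprobs, shape [T]

indices = torch.unique_consecutive(frame_argmax)  # collapse repeated frames: [0, 1, 2, 0, 3]
sequence = [labels[i] for i in indices if not labels[i].startswith("<")]  # drop blank/special symbols
text = " ".join(sequence).replace("@@ ", "")      # undo BPE "@@ " joining, as in forward_step
print(text)                                       # HELLO WORLD
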
users/berger/args/experiments/ctc.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def get_ctc_recog_step_args(num_classes: int, reduction_factor: int = 4, **kwarg
             "mem_rqmt": 16,
         },
         "rtf": 20,
-        "mem": 4,
+        "mem": 8,
     }
 
     return recursive_update(default_args, kwargs)

users/berger/args/experiments/transducer.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def get_transducer_recog_step_args(
             "mem_rqmt": 16,
         },
         "rtf": 50,
-        "mem": 4,
+        "mem": 8,
     }
 
     return recursive_update(default_args, kwargs)

users/berger/args/jobs/rasr_init_args.py

Lines changed: 25 additions & 0 deletions
@@ -91,6 +91,7 @@ def get_feature_extraction_args_16kHz(
     gt_args: Optional[Dict] = None,
 ) -> Dict:
     mfcc_filter_width = features.filter_width_from_channels(channels=20, f_max=8000)  # = 16000 / 2
+    filterbank_filter_width = features.filter_width_from_channels(channels=80, f_max=8000)  # = 16000 / 2
 
     if mfcc_cepstrum_options is None:
         mfcc_cepstrum_options = {
@@ -142,6 +143,30 @@ def get_feature_extraction_args_16kHz(
                 "normalization_options": {},
             }
         },
+        "filterbank": {
+            "filterbank_options": {
+                "warping_function": "mel",
+                "filter_width": filterbank_filter_width,
+                "normalize": False,
+                "normalization_options": {},
+                "without_samples": False,
+                "samples_options": {
+                    "audio_format": "wav",
+                    # "scale_input": 2**-15,
+                    "dc_detection": dc_detection,
+                },
+                "fft_options": {
+                    "preemphasis": 0.97,
+                    "window_type": "hanning",
+                    "window_shift": 0.01,
+                    "window_length": 0.025,
+                },
+                "apply_log": True,
+                "add_epsilon": True,
+                "add_features_output": True,
+                # "warp_differential_unit": False,
+            },
+        },
         "energy": {
             "energy_options": {
                 "without_samples": False,

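Note: a quick back-of-the-envelope for the 16 kHz filterbank options added above (window parameters are given in seconds). The snippet below only restates those numbers for orientation and is not part of the setup.

# Hypothetical sanity check; the numbers restate the fft_options above.
sample_rate = 16000                            # 16 kHz audio
window_shift_s, window_length_s = 0.01, 0.025  # window_shift / window_length from fft_options
print(int(window_shift_s * sample_rate))       # 160 samples shift -> 100 frames per second
print(int(window_length_s * sample_rate))      # 400 samples per analysis window
# 80 mel channels up to f_max = 8000 Hz (= 16000 / 2), with log applied -> 80-dim log-mel features
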
users/berger/args/returnn/config.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ def get_base_config(backend: Backend) -> Dict[str, Any]:
     elif backend == Backend.PYTORCH:
         result["backend"] = "torch"
         result["use_lovely_tensors"] = True
+        # result["torch_amp"] = {"dtype": "bfloat16"}
     else:
         raise NotImplementedError
     return result

users/berger/args/returnn/learning_rates.py

Lines changed: 55 additions & 0 deletions
@@ -10,6 +10,7 @@ class LearningRateSchedules(Enum):
     NewbobAbs = auto()
     OCLR = auto()
     OCLR_STEP = auto()
+    OCLR_STEP_TORCH = auto()
     CONST_DECAY = auto()
     CONST_DECAY_STEP = auto()
 
@@ -38,6 +39,8 @@ def get_learning_rate_config(
         config.update(get_oclr_config(**kwargs))
     elif schedule == LearningRateSchedules.OCLR_STEP:
         extra_python.append(get_oclr_function(**kwargs))
+    elif schedule == LearningRateSchedules.OCLR_STEP_TORCH:
+        extra_python.append(get_oclr_function_torch(**kwargs))
     elif schedule == LearningRateSchedules.CONST_DECAY:
         config.update(get_const_decay_config(**kwargs))
     elif schedule == LearningRateSchedules.CONST_DECAY_STEP:
@@ -184,6 +187,58 @@ def get_oclr_function(
     )
 
 
+def get_oclr_function_torch(
+    num_epochs: int,
+    n_steps_per_epoch: int,
+    peak_lr: float = 1e-03,
+    inc_epochs: Optional[int] = None,
+    dec_epochs: Optional[int] = None,
+    initial_lr: Optional[float] = None,
+    decayed_lr: Optional[float] = None,
+    final_lr: Optional[float] = None,
+    **kwargs,
+) -> str:
+    initial_lr = initial_lr or peak_lr / 10
+    decayed_lr = decayed_lr or initial_lr
+    final_lr = final_lr or initial_lr / 5
+    inc_epochs = inc_epochs or (num_epochs * 9) // 20
+    dec_epochs = dec_epochs or inc_epochs
+
+    return dedent(
+        f"""def dynamic_learning_rate(*, global_train_step: int, **_):
+            # Increase linearly from initial_lr to peak_lr over the first inc_epoch epochs
+            # Decrease linearly from peak_lr to decayed_lr over the next dec_epoch epochs
+            # Decrease linearly from decayed_lr to final_lr over the remaining epochs
+            initial_lr = {initial_lr}
+            peak_lr = {peak_lr}
+            decayed_lr = {decayed_lr}
+            final_lr = {final_lr}
+            inc_epochs = {inc_epochs}
+            dec_epochs = {dec_epochs}
+            total_epochs = {num_epochs}
+            n_steps_per_epoch = {n_steps_per_epoch}
+
+            # -- derived -- #
+            steps_increase = inc_epochs * n_steps_per_epoch
+            steps_decay = dec_epochs * n_steps_per_epoch
+            steps_final = (total_epochs - inc_epochs - dec_epochs) * n_steps_per_epoch
+
+            step_size_increase = (peak_lr - initial_lr) / steps_increase
+            step_size_decay = (peak_lr - decayed_lr) / steps_decay
+            step_size_final = (decayed_lr - final_lr) / steps_final
+
+            if global_train_step <= steps_increase:
+                return initial_lr + step_size_increase * global_train_step
+            if global_train_step <= steps_increase + steps_decay:
+                return peak_lr - step_size_decay * (global_train_step - steps_increase)
+
+            return max(
+                decayed_lr - step_size_final * (global_train_step - steps_increase - steps_decay),
+                final_lr
+            )"""
+    )
+
+
 def get_const_decay_config(
     num_epochs: int,
     const_lr: float = 1e-03,

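Note: a small sanity-check sketch of the piecewise-linear schedule that get_oclr_function_torch generates. The concrete values (20 epochs, 100 steps per epoch, peak_lr = 1e-3) are assumptions for illustration only; the defaults are derived from the peak LR in the same way as in the function above.

# Sketch with assumed numbers; mirrors the generated dynamic_learning_rate code above.
def dynamic_learning_rate(*, global_train_step: int, **_):
    initial_lr, peak_lr = 1e-4, 1e-3    # initial_lr = peak_lr / 10
    decayed_lr, final_lr = 1e-4, 2e-5   # decayed_lr = initial_lr, final_lr = initial_lr / 5
    inc_epochs = dec_epochs = 9         # (20 * 9) // 20
    total_epochs, n_steps_per_epoch = 20, 100

    steps_increase = inc_epochs * n_steps_per_epoch  # 900 warm-up steps
    steps_decay = dec_epochs * n_steps_per_epoch     # 900 decay steps
    steps_final = (total_epochs - inc_epochs - dec_epochs) * n_steps_per_epoch  # 200 final steps

    step_size_increase = (peak_lr - initial_lr) / steps_increase
    step_size_decay = (peak_lr - decayed_lr) / steps_decay
    step_size_final = (decayed_lr - final_lr) / steps_final

    if global_train_step <= steps_increase:
        return initial_lr + step_size_increase * global_train_step
    if global_train_step <= steps_increase + steps_decay:
        return peak_lr - step_size_decay * (global_train_step - steps_increase)
    return max(decayed_lr - step_size_final * (global_train_step - steps_increase - steps_decay), final_lr)

for step in (0, 900, 1800, 1999):
    print(step, dynamic_learning_rate(global_train_step=step))
# 0 -> 1e-4 (initial), 900 -> 1e-3 (peak), 1800 -> 1e-4 (decayed), 1999 -> ~2e-5 (near final)
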
users/berger/args/returnn/regularization.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def get_chunking_config(
 
     if isinstance(chunking_factors, list):
         chunking_factors = {key: 1 for key in chunking_factors}
-    assert isinstance(chunking_factors, Dict)
+    assert isinstance(chunking_factors, dict)
     return {
         "chunking": (
             {key: base_chunk_size // factor for key, factor in chunking_factors.items()},

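Note: a tiny illustration of the chunking_factors handling shown in the diff above; the example factors and base_chunk_size are made up for this sketch.

# Hypothetical values, only to illustrate the dict handling above.
base_chunk_size = 400
chunking_factors = ["data", "classes"]                    # list form ...
chunking_factors = {key: 1 for key in chunking_factors}   # ... becomes {"data": 1, "classes": 1}
assert isinstance(chunking_factors, dict)                 # built-in dict as the isinstance target
print({key: base_chunk_size // factor for key, factor in chunking_factors.items()})
# {'data': 400, 'classes': 400}
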
users/berger/configs/librispeech/20230210_baselines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 from .config_02c_transducer_rasr_features_wei_lex import py as py_02c
 from .config_02e_transducer_rasr_features_tinaconf import py as py_02e
 from .config_02e_transducer_rasr_features_tinaconf_rtf import py as py_02e_rtf
+from .config_02f_transducer_rasr_features_am_scales import py as py_02f
 from .config_03a_transducer_fullsum_raw_samples import py as py_03a
 from .config_03b_transducer_fullsum_rasr_features import py as py_03b
 from .config_03c_transducer_fullsum_rasr_features_wei_lex import py as py_03c
@@ -37,6 +38,7 @@ def main() -> SummaryReport:
     sub_reports.append(copy.deepcopy(py_02c()[0]))
     sub_reports.append(copy.deepcopy(py_02e()))
     sub_reports.append(copy.deepcopy(py_02e_rtf()))
+    sub_reports.append(copy.deepcopy(py_02f()))
     sub_reports.append(copy.deepcopy(py_03a()))
     sub_reports.append(copy.deepcopy(py_03b()))
     sub_reports.append(copy.deepcopy(py_03c()))
