Skip to content

Commit 286a2ab

Browse files
committed
add jaist synthetic data project
1 parent 8c294b3 commit 286a2ab

File tree

112 files changed

+45533
-0
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

112 files changed

+45533
-0
lines changed

users/rossenbach/experiments/jaist_project/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Fully qualified dotted name of this package (via __package__) — presumably
# exported so sibling modules can build absolute code-object paths for
# serialization; verify against usages before relying on this.
PACKAGE = __package__
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import copy
2+
import numpy as np
3+
from sisyphus import tk
4+
from typing import Any, Dict, Optional
5+
6+
from i6_core.returnn.config import ReturnnConfig, CodeWrapper
7+
8+
from i6_experiments.common.setups.returnn_pytorch.serialization import (
9+
Collection as TorchCollection,
10+
)
11+
from i6_experiments.common.setups.serialization import Import
12+
from .data.common import TrainingDatasets
13+
from .serializer import get_pytorch_serializer_v3, PACKAGE
14+
15+
from i6_experiments.users.rossenbach.common_setups.returnn.datasets import GenericDataset
16+
17+
18+
def get_training_config(
    training_datasets: TrainingDatasets,
    network_module: str,
    net_args: Dict[str, Any],
    config: Dict[str, Any],
    debug: bool = False,
    use_custom_engine: bool = False,
    use_speed_perturbation: bool = False,
    post_config: Optional[Dict[str, Any]] = None,
) -> ReturnnConfig:
    """
    Build the RETURNN training config for a PyTorch model.

    :param training_datasets: datasets for training (train / cv / devtrain are used here)
    :param network_module: path to the pytorch config file containing Model
    :param net_args: extra arguments for the model
    :param config: entries merged on top of the base config; caller entries win
        and are part of the job hash
    :param debug: run training in debug mode (linking from recipe instead of copy)
    :param use_custom_engine: serialize a custom training engine from the network module
    :param use_speed_perturbation: prepend legacy speed perturbation as audio pre-processing;
        assumes config["train"]["datasets"]["zip_dataset"]["audio"] exists — TODO confirm
        this holds for every dataset layout used with this flag
    :param post_config: entries merged on top of the base post config; post-config
        entries do not influence the job hash
    :return: ReturnnConfig with the serialized network as python epilog
    """

    # changing these does not change the hash
    base_post_config = {
        "cleanup_old_models": True,
        "stop_on_nonfinite_train_score": True,  # this might break now with True
        "num_workers_per_gpu": 2,
        "backend": "torch"
    }

    base_config = {
        #############
        "train": copy.deepcopy(training_datasets.train.as_returnn_opts()),
        "dev": training_datasets.cv.as_returnn_opts(),
        "eval_datasets": {
            "devtrain": training_datasets.devtrain.as_returnn_opts()
        }
    }
    # deepcopy the caller dicts so later in-place edits (speed perturbation below)
    # can never leak back into the caller's objects
    config = {**base_config, **copy.deepcopy(config)}
    post_config = {**base_post_config, **copy.deepcopy(post_config or {})}

    serializer = get_pytorch_serializer_v3(
        network_module=network_module,
        net_args=net_args,
        debug=debug,
        use_custom_engine=use_custom_engine
    )
    python_prolog = None

    # TODO: maybe make nice
    if use_speed_perturbation:
        # the prolog makes legacy_speed_perturbation importable by name inside
        # the generated config, so the CodeWrapper reference below resolves
        prolog_serializer = TorchCollection(
            serializer_objects=[Import(
                code_object_path=PACKAGE + ".extra_code.speed_perturbation.legacy_speed_perturbation",
                unhashed_package_root=PACKAGE
            )]
        )
        python_prolog = [prolog_serializer]
        config["train"]["datasets"]["zip_dataset"]["audio"]["pre_process"] = CodeWrapper("legacy_speed_perturbation")

    returnn_config = ReturnnConfig(
        config=config, post_config=post_config, python_prolog=python_prolog, python_epilog=[serializer]
    )
    return returnn_config
78+
79+
80+
def get_prior_config(
    training_datasets: TrainingDatasets,
    network_module: str,
    net_args: Dict[str, Any],
    config: Dict[str, Any],
    debug: bool = False,
    use_custom_engine=False,
    **kwargs,
):
    """
    Build the RETURNN forward config used for prior computation over the
    dedicated prior dataset.

    :param training_datasets: dataset collection; only ``training_datasets.prior``
        is used here (as the "forward" dataset)
    :param network_module: path to the pytorch config file containing Model
    :param net_args: extra arguments for the model
    :param config: entries merged on top of the base config; caller entries win
        and are part of the job hash
    :param debug: run in debug mode (linking from recipe instead of copy)
    :param use_custom_engine: serialize a custom engine from the network module
    :param kwargs: ignored; accepted only for call-compatibility with the other
        config builders in this file
    :return: forward ReturnnConfig for prior estimation
    """

    # changing these does not change the hash
    post_config = {
    }

    base_config = {
        #############
        "batch_size": 2000 * 16000,
        "max_seqs": 240,
        #############
        "forward": training_datasets.prior.as_returnn_opts()

    }
    # caller entries override the base entries; deepcopy avoids aliasing
    config = {**base_config, **copy.deepcopy(config)}
    post_config["backend"] = "torch"

    serializer = get_pytorch_serializer_v3(
        network_module=network_module,
        net_args=net_args,
        debug=debug,
        use_custom_engine=use_custom_engine,
        prior=True,
    )
    returnn_config = ReturnnConfig(
        config=config, post_config=post_config, python_epilog=[serializer]
    )
    return returnn_config
123+
124+
125+
126+
def get_forward_config(
    network_module: str,
    net_args: Dict[str, Any],
    decoder: str,
    decoder_args: Dict[str, Any],
    config: Dict[str, Any],
    debug: bool = False,
    use_custom_engine=False,
    **kwargs,
):
    """
    Build the RETURNN config for search/forwarding. No dataset is set here;
    it is added later in the pipeline (during search_single).

    :param network_module: path to the pytorch config file containing Model
    :param net_args: extra arguments for the model
    :param decoder: decoder code object to serialize — presumably a module path
        string like ``network_module``; the previous annotation ``[str]`` was a
        list literal, not a valid type expression
    :param decoder_args: extra arguments for the decoder
    :param config: entries merged on top of the base config; caller entries win
        and are part of the job hash
    :param debug: run in debug mode (linking from recipe instead of copy)
    :param use_custom_engine: serialize a custom engine from the network module
    :param kwargs: ignored; accepted only for call-compatibility with the other
        config builders in this file
    :return: forward ReturnnConfig for search
    """

    # changing these does not change the hash
    post_config = {
    }

    base_config = {
        #############
        "batch_size": 1000 * 16000,
        "max_seqs": 240,
        #############
        # dataset is added later in the pipeline during search_single
    }
    # caller entries override the base entries; deepcopy avoids aliasing
    config = {**base_config, **copy.deepcopy(config)}
    post_config["backend"] = "torch"

    serializer = get_pytorch_serializer_v3(
        network_module=network_module,
        net_args=net_args,
        debug=debug,
        use_custom_engine=use_custom_engine,
        decoder=decoder,
        decoder_args=decoder_args,
    )
    returnn_config = ReturnnConfig(
        config=config, post_config=post_config, python_epilog=[serializer]
    )
    return returnn_config

users/rossenbach/experiments/jaist_project/standalone_2024/ctc_bpe/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)