Skip to content

Commit 18d3f1a

Browse files
Implements two front-ends for acoustic encoders (#17)
* add two frontends * cleaner docstring * fix padding * extend, fixes, tests * consistency to other code * unify naming * rename vars * naming * rm static method wrapper * update doc * rename frontend classes * mv to new dir * more * consistency * rm self * fix error type * black * add structure doc to class * strides -> stride * update docs * more configurable * improve front end forward func * more * fixes * fixes * more * extend doc * more * more * add protocol * start implementing seq mask * extend tests and more * add linear layer * Apply suggestions by albertz Co-authored-by: Albert Zeyer <[email protected]> * black * add get func for int|tuple[int, ...] * add pool masking of seq mask * typing * add assert * mv list call * add masking reference * fix masks * fixes * switch mask 0<>1 * follow vgg def more closely * fix * fix * fixes * add print statement to test * fixes to structure * doc * fixes * split tests into 2 files * wip * mv vgg pool variant -> new PR * cleanup * black * fix typing * cleanup * handle masking None * check * missing not * black * mv cast into _get_mask * black * update doc * add optional flag * black * rm protocol * rewrite tests * rm optional mask output * add doc: explain mask update locations * funcs non private * fix * add seq masking for all steps * black * updates * updates * cleanup and fix type hint * cleanup * updates and fixes * doc * remove implicit assumption for dim update * smaller namespace * cleanup * force usage of tuple, no int * missing conversion int to tuple * better type hinting * add padding var * turn into kwargs * more * fix padding kernel check for mask pool --------- Co-authored-by: Albert Zeyer <[email protected]>
1 parent 2e76cd9 commit 18d3f1a

File tree

5 files changed

+629
-0
lines changed

5 files changed

+629
-0
lines changed

i6_models/parts/frontend/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Different front-ends for acoustic encoders
2+
3+
### Contributing
4+
5+
If you want to add your own front-end:
6+
7+
- Normally two classes are required: a config class and a model class.
8+
- `Config` class inherits from `ModelConfiguration`
9+
- `Model` class inherits from `nn.Module` from `torch`
10+
- `forward(tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]`
11+
- `sequence_mask` is a boolean tensor where `True` marks a position inside the sequence and `False` marks a masked position.
12+
- Please add tests

i6_models/parts/frontend/__init__.py

Whitespace-only changes.

i6_models/parts/frontend/common.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import Tuple, Union
2+
3+
import torch
4+
from torch import nn
5+
from torch.nn import functional
6+
7+
8+
def get_same_padding(input_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    """
    Compute the padding ``(size - 1) // 2`` for a kernel size.

    For an odd kernel size at stride 1 this padding keeps the length of the
    convolved dimension (e.g. the time dimension) unchanged.

    :param input_size: kernel size, either a single int or a tuple with one entry per dimension
    :return: padding in the same form as the input (int for int, tuple for tuple)
    :raises TypeError: if ``input_size`` is neither an int nor a tuple
    """

    def _half(size: int) -> int:
        return (size - 1) // 2

    if isinstance(input_size, int):
        return _half(input_size)
    if not isinstance(input_size, tuple):
        raise TypeError(f"unexpected size type {type(input_size)}")
    return tuple(_half(size) for size in input_size)
21+
22+
23+
def mask_pool(seq_mask: torch.Tensor, *, kernel_size: int, stride: int, padding: int) -> torch.Tensor:
    """
    Apply the striding/pooling of a layer to the sequence mask.

    The mask is max-pooled along the time axis, so an output frame is marked
    valid as soon as at least one input frame in its window is valid.

    :param seq_mask: sequence mask of shape [B,T]
    :param kernel_size: kernel size of the layer along the time axis
    :param stride: stride of the layer along the time axis
    :param padding: padding of the layer along the time axis
    :return: sequence mask of shape [B,T'] using maxpool; the input mask is returned
        unchanged if the layer does not alter the time dimension
    """
    keeps_length = stride == 1 and 2 * padding == kernel_size - 1
    if keeps_length:
        # stride 1 with "same" padding: T' == T, nothing to pool
        return seq_mask

    # max_pool1d expects a floating point input with a channel axis
    pooled = seq_mask.float().unsqueeze(1)  # [B,1,T]
    pooled = torch.nn.functional.max_pool1d(pooled, kernel_size, stride, padding)  # [B,1,T']
    return pooled.squeeze(1).bool()  # [B,T']
42+
43+
44+
def calculate_output_dim(in_dim: int, *, filter_size: int, stride: int, padding: int) -> int:
    """
    Compute the output size of one dimension after a convolution or pooling layer.

    Evaluates ``ceil((in_dim + 2 * padding - (filter_size - 1)) / stride)``,
    i.e. a dilation of 1 is assumed.

    :param in_dim: size of the dimension before the layer
    :param filter_size: kernel size of the layer along that dimension
    :param stride: stride of the layer along that dimension
    :param padding: padding applied on each side of that dimension
    :return: size of the dimension after the layer
    """
    span = in_dim + 2 * padding - (filter_size - 1)
    # ceiling division: round up whenever the division leaves a remainder
    quotient, remainder = divmod(span, stride)
    return quotient + bool(remainder)
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
from __future__ import annotations
2+
3+
__all__ = [
4+
"VGG4LayerActFrontendV1",
5+
"VGG4LayerActFrontendV1Config",
6+
]
7+
8+
from dataclasses import dataclass
9+
from typing import Callable, Optional, Tuple, Union
10+
11+
import torch
12+
from torch import nn
13+
14+
from i6_models.config import ModelConfiguration
15+
16+
from .common import get_same_padding, mask_pool, calculate_output_dim
17+
18+
19+
@dataclass
class VGG4LayerActFrontendV1Config(ModelConfiguration):
    """
    Configuration of the VGG4LayerActFrontendV1 front-end.

    Attributes:
        in_features: number of input features to module
        conv1_channels: number of channels for first conv layer
        conv2_channels: number of channels for second conv layer
        conv3_channels: number of channels for third conv layer
        conv4_channels: number of channels for fourth conv layer
        conv_kernel_size: kernel size of conv layers
        conv_padding: padding for the convolution; if None, the front-end derives
            it from conv_kernel_size via get_same_padding
        pool1_kernel_size: kernel size of first pooling layer
        pool1_stride: stride of first pooling layer; passed unchanged to nn.MaxPool2d
        pool1_padding: padding for first pooling layer; if None, the front-end uses (0, 0)
        pool2_kernel_size: kernel size of second pooling layer
        pool2_stride: stride of second pooling layer; passed unchanged to nn.MaxPool2d
        pool2_padding: padding for second pooling layer; if None, the front-end uses (0, 0)
        activation: activation function at the end
        out_features: output size of the final linear layer
    """

    in_features: int
    conv1_channels: int
    conv2_channels: int
    conv3_channels: int
    conv4_channels: int
    conv_kernel_size: Tuple[int, int]
    conv_padding: Optional[Tuple[int, int]]
    pool1_kernel_size: Tuple[int, int]
    pool1_stride: Optional[Tuple[int, int]]
    pool1_padding: Optional[Tuple[int, int]]
    pool2_kernel_size: Tuple[int, int]
    pool2_stride: Optional[Tuple[int, int]]
    pool2_padding: Optional[Tuple[int, int]]
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
    out_features: int

    def check_valid(self):
        """
        Assert that kernel sizes given as a plain int are odd.

        The fields are annotated as tuples; tuple values are not inspected here,
        the check only applies if a kernel size is passed as a single int.

        :raises AssertionError: if an int kernel size is even
        """
        for kernel_size in (self.conv_kernel_size, self.pool1_kernel_size, self.pool2_kernel_size):
            if isinstance(kernel_size, int):
                assert kernel_size % 2 == 1, "VGG4LayerActFrontendV1 only supports odd kernel sizes"

    def __post_init__(self):
        # dataclasses only call the hook named exactly `__post_init__`,
        # so the validation has to live under this name to run at construction
        super().__post_init__()
        self.check_valid()
67+
68+
69+
class VGG4LayerActFrontendV1(nn.Module):
    """
    Convolutional Front-End

    The front-end utilizes convolutional and pooling layers, as well as activation functions
    to transform a feature vector, typically Log-Mel or Gammatone for audio, into an intermediate
    representation.

    Structure of the front-end:
      - Conv
      - Conv
      - Activation
      - Pool
      - Conv
      - Conv
      - Activation
      - Pool
      - Linear

    Uses explicit padding for ONNX exportability, see:
    https://github.com/pytorch/pytorch/issues/68880
    """

    def __init__(self, model_cfg: VGG4LayerActFrontendV1Config):
        """
        :param model_cfg: model configuration for this module
        """
        super().__init__()

        model_cfg.check_valid()

        self.cfg = model_cfg

        # resolve optional paddings to explicit values
        conv_padding = model_cfg.conv_padding
        if conv_padding is None:
            conv_padding = get_same_padding(model_cfg.conv_kernel_size)
        pool1_padding = (0, 0) if model_cfg.pool1_padding is None else model_cfg.pool1_padding
        pool2_padding = (0, 0) if model_cfg.pool2_padding is None else model_cfg.pool2_padding

        def make_conv(in_channels: int, out_channels: int) -> nn.Conv2d:
            # all four convolutions share kernel size and padding
            return nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=model_cfg.conv_kernel_size,
                padding=conv_padding,
            )

        self.conv1 = make_conv(1, model_cfg.conv1_channels)
        self.conv2 = make_conv(model_cfg.conv1_channels, model_cfg.conv2_channels)
        self.pool1 = nn.MaxPool2d(
            kernel_size=model_cfg.pool1_kernel_size,
            stride=model_cfg.pool1_stride,
            padding=pool1_padding,
        )
        self.conv3 = make_conv(model_cfg.conv2_channels, model_cfg.conv3_channels)
        self.conv4 = make_conv(model_cfg.conv3_channels, model_cfg.conv4_channels)
        self.pool2 = nn.MaxPool2d(
            kernel_size=model_cfg.pool2_kernel_size,
            stride=model_cfg.pool2_stride,
            padding=pool2_padding,
        )
        self.activation = model_cfg.activation
        # the linear layer consumes the flattened [C*F"] axis, so its input size
        # has to be derived from the layers created above
        self.linear = nn.Linear(
            in_features=self._calculate_dim(),
            out_features=model_cfg.out_features,
            bias=True,
        )

    def forward(self, tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        T might be reduced to T' or T'' depending on stride of the layers

        Only pool1 and pool2 have a configurable stride. The sequence mask is
        nevertheless updated after every conv and pool layer.

        :param tensor: input tensor of shape [B,T,F]
        :param sequence_mask: the sequence mask for the tensor
        :return: torch.Tensor of shape [B,T",F'] and the sequence mask updated to T"
        """
        assert tensor.shape[-1] == self.cfg.in_features
        # insert the channel axis expected by the 2D layers
        tensor = tensor[:, None, :, :]  # [B,C=1,T,F]

        # --- first conv block ---
        tensor = self.conv1(tensor)
        sequence_mask = mask_pool(
            seq_mask=sequence_mask,
            kernel_size=self.conv1.kernel_size[0],
            stride=self.conv1.stride[0],
            padding=self.conv1.padding[0],
        )

        tensor = self.conv2(tensor)
        sequence_mask = mask_pool(
            seq_mask=sequence_mask,
            kernel_size=self.conv2.kernel_size[0],
            stride=self.conv2.stride[0],
            padding=self.conv2.padding[0],
        )

        tensor = self.pool1(self.activation(tensor))  # [B,C,T',F']
        sequence_mask = mask_pool(
            seq_mask=sequence_mask,
            kernel_size=self.pool1.kernel_size[0],
            stride=self.pool1.stride[0],
            padding=self.pool1.padding[0],
        )

        # --- second conv block ---
        tensor = self.conv3(tensor)
        sequence_mask = mask_pool(
            seq_mask=sequence_mask,
            kernel_size=self.conv3.kernel_size[0],
            stride=self.conv3.stride[0],
            padding=self.conv3.padding[0],
        )

        tensor = self.conv4(tensor)
        sequence_mask = mask_pool(
            seq_mask=sequence_mask,
            kernel_size=self.conv4.kernel_size[0],
            stride=self.conv4.stride[0],
            padding=self.conv4.padding[0],
        )

        tensor = self.pool2(self.activation(tensor))  # [B,C,T",F"]
        sequence_mask = mask_pool(
            seq_mask=sequence_mask,
            kernel_size=self.pool2.kernel_size[0],
            stride=self.pool2.stride[0],
            padding=self.pool2.padding[0],
        )

        # --- merge channel and feature axes, project to the output size ---
        tensor = tensor.transpose(1, 2)  # [B,T",C,F"]
        tensor = tensor.flatten(start_dim=2, end_dim=-1)  # [B,T",C*F"]

        tensor = self.linear(tensor)

        return tensor, sequence_mask

    def _calculate_dim(self) -> int:
        """
        Compute the input size of the final linear layer.

        Follows the feature axis (index 1 of kernel size, stride and padding)
        through all conv and pool layers and multiplies the resulting size by
        the number of output channels of the last convolution.

        :return: size of the flattened [C*F"] axis
        """
        feature_dim = self.cfg.in_features
        layers_in_order = (self.conv1, self.conv2, self.pool1, self.conv3, self.conv4, self.pool2)
        for layer in layers_in_order:
            feature_dim = calculate_output_dim(
                in_dim=feature_dim,
                filter_size=layer.kernel_size[1],
                stride=layer.stride[1],
                padding=layer.padding[1],
            )
        return feature_dim * self.conv4.out_channels

0 commit comments

Comments
 (0)