Skip to content

Commit 3ec9305

Browse files
authored
Explicit padding for ConformerConvolutionV1 (#16)
1 parent 04fafab commit 3ec9305

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

i6_models/parts/conformer/convolution.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,35 @@ class ConformerConvolutionV1Config(ModelConfiguration):
2323
norm: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
2424
"""normalization layer with input of shape [N,C,T]"""
2525

26+
def check_valid(self):
27+
assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"
28+
29+
def __post_init__(self):
30+
super().__post_init__()
31+
self.check_valid()
32+
2633

2734
class ConformerConvolutionV1(nn.Module):
2835
"""
2936
Conformer convolution module.
3037
see also: https://github.com/espnet/espnet/blob/713e784c0815ebba2053131307db5f00af5159ea/espnet/nets/pytorch_backend/conformer/convolution.py#L13
38+
39+
Uses explicit padding for ONNX exportability, see:
40+
https://github.com/pytorch/pytorch/issues/68880
3141
"""
3242

3343
def __init__(self, model_cfg: ConformerConvolutionV1Config):
3444
"""
3545
:param model_cfg: model configuration for this module
3646
"""
3747
super().__init__()
38-
48+
model_cfg.check_valid()
3949
self.pointwise_conv1 = nn.Linear(in_features=model_cfg.channels, out_features=2 * model_cfg.channels)
4050
self.depthwise_conv = nn.Conv1d(
4151
in_channels=model_cfg.channels,
4252
out_channels=model_cfg.channels,
4353
kernel_size=model_cfg.kernel_size,
44-
padding="same",
54+
padding=(model_cfg.kernel_size - 1) // 2,
4555
groups=model_cfg.channels,
4656
)
4757
self.pointwise_conv2 = nn.Linear(in_features=model_cfg.channels, out_features=model_cfg.channels)

tests/test_conformer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0
3232
assert get_output_shape(10, 1, 50) == (10, 1, 50) # time dim 1
3333
assert get_output_shape(10, 10, 20, dropout=0.0) == (10, 10, 20) # dropout 0
3434
assert get_output_shape(10, 10, 20, kernel_size=3) == (10, 10, 20) # odd kernel size
35-
assert get_output_shape(10, 10, 20, kernel_size=32) == (10, 10, 20) # even kernel size
3635

3736

3837
def test_ConformerPositionwiseFeedForwardV1():

0 commit comments

Comments
 (0)