
Commit 147c00c

mmz33, JackTemaki, albertz, Atticus1806, and Eugen Beck authored
Implement conformer convolution part (#4)
Co-authored-by: Nick Rossenbach <[email protected]>
Co-authored-by: Albert Zeyer <[email protected]>
Co-authored-by: Benedikt Hilmes <[email protected]>
Co-authored-by: Eugen Beck <[email protected]>
1 parent 0f20821 commit 147c00c

File tree

4 files changed (+135 −1 lines changed):

i6_models/parts/conformer/convolution.py
i6_models/parts/conformer/norm.py
requirements.txt
tests/test_conformer.py

i6_models/parts/conformer/convolution.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn
from i6_models.config import ModelConfiguration
from typing import Callable, Union


@dataclass
class ConformerConvolutionV1Config(ModelConfiguration):
    channels: int
    """number of channels for conv layers"""
    kernel_size: int
    """kernel size of conv layers"""
    dropout: float
    """dropout probability"""
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
    """activation function applied after norm"""
    norm: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
    """normalization layer with input of shape [N,C,T]"""


class ConformerConvolutionV1(nn.Module):
    """
    Conformer convolution module.
    see also: https://github.com/espnet/espnet/blob/713e784c0815ebba2053131307db5f00af5159ea/espnet/nets/pytorch_backend/conformer/convolution.py#L13
    """

    def __init__(self, model_cfg: ConformerConvolutionV1Config):
        """
        :param model_cfg: model configuration for this module
        """
        super().__init__()

        self.pointwise_conv1 = nn.Linear(in_features=model_cfg.channels, out_features=2 * model_cfg.channels)
        self.depthwise_conv = nn.Conv1d(
            in_channels=model_cfg.channels,
            out_channels=model_cfg.channels,
            kernel_size=model_cfg.kernel_size,
            padding="same",
            groups=model_cfg.channels,
        )
        self.pointwise_conv2 = nn.Linear(in_features=model_cfg.channels, out_features=model_cfg.channels)
        self.layer_norm = nn.LayerNorm(model_cfg.channels)
        self.norm = model_cfg.norm
        self.dropout = nn.Dropout(model_cfg.dropout)
        self.activation = model_cfg.activation

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: input tensor of shape [B,T,F]
        :return: torch.Tensor of shape [B,T,F]
        """
        tensor = self.layer_norm(tensor)
        tensor = self.pointwise_conv1(tensor)  # [B,T,2F]
        tensor = nn.functional.glu(tensor, dim=-1)  # [B,T,F]

        # conv layers expect shape [B,F,T] so we have to transpose here
        tensor = tensor.transpose(1, 2)  # [B,F,T]
        tensor = self.depthwise_conv(tensor)

        tensor = self.norm(tensor)
        tensor = tensor.transpose(1, 2)  # transpose back to [B,T,F]

        tensor = self.activation(tensor)
        tensor = self.pointwise_conv2(tensor)

        return self.dropout(tensor)
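
For orientation (not part of the commit): the forward pass follows the convolution block of the Conformer architecture (Gulati et al., 2020), i.e. LayerNorm, pointwise expansion from F to 2F, GLU back to F, depthwise convolution, norm, activation, pointwise projection, dropout. A minimal usage sketch; the concrete sizes and the BatchNorm1d/SiLU choices are illustrative, mirroring what the tests below exercise:

import torch
from torch import nn

from i6_models.parts.conformer.convolution import ConformerConvolutionV1, ConformerConvolutionV1Config

cfg = ConformerConvolutionV1Config(
    channels=256,  # F, the feature dim of the incoming [B,T,F] tensor
    kernel_size=31,  # depthwise kernel; padding="same" keeps T unchanged
    dropout=0.1,
    activation=nn.functional.silu,
    norm=nn.BatchNorm1d(256),  # applied to the [B,F,T] layout inside forward
)
conv_part = ConformerConvolutionV1(cfg)

x = torch.randn(8, 100, 256)  # [B,T,F]
y = conv_part(x)
assert y.shape == x.shape  # the module preserves [B,T,F]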

i6_models/parts/conformer/norm.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn


class LayerNormNC(nn.LayerNorm):
    """
    LayerNorm that accepts [N,C,*] tensors and normalizes over C (channels) dimension.
    see here: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
    """

    def __init__(self, channels: int):
        """
        :param channels: number of channels for normalization
        """
        super().__init__(channels)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: input tensor with shape [N,C,*]
        :return: normalized tensor with shape [N,C,*]
        """
        return super().forward(tensor.transpose(1, -1)).transpose(1, -1)
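
A short illustrative sketch (not part of the commit) of why this class exists: both nn.BatchNorm1d and LayerNormNC accept the [N,C,T] layout that ConformerConvolutionV1 feeds its norm, so either can be passed as the norm field of the config; they differ only in which statistics they compute:

import torch
from torch import nn

from i6_models.parts.conformer.norm import LayerNormNC

batch_norm = nn.BatchNorm1d(64)  # per-channel statistics over batch and time
layer_norm_nc = LayerNormNC(64)  # per-frame statistics over the channel dim

x = torch.randn(4, 64, 25)  # [N,C,T]
assert batch_norm(x).shape == layer_norm_nc(x).shape == x.shape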

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 typeguard
-torch
+torch

tests/test_conformer.py

Lines changed: 42 additions & 0 deletions
@@ -3,10 +3,12 @@
 import torch
 from torch import nn
 
+from i6_models.parts.conformer.convolution import ConformerConvolutionV1, ConformerConvolutionV1Config
 from i6_models.parts.conformer.feedforward import (
     ConformerPositionwiseFeedForwardV1,
     ConformerPositionwiseFeedForwardV1Config,
 )
+from i6_models.parts.conformer.norm import LayerNormNC
 
 
 def test_ConformerPositionwiseFeedForwardV1():
@@ -22,3 +24,43 @@ def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
     ):
         input_shape = (10, 100, input_dim)
         assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape
+
+
+def test_conformer_convolution_output_shape():
+    def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0.1, activation=nn.functional.silu):
+        x = torch.randn(batch, time, features)
+        if norm is None:
+            norm = nn.BatchNorm1d(features)
+        cfg = ConformerConvolutionV1Config(
+            channels=features, kernel_size=kernel_size, dropout=dropout, activation=activation, norm=norm
+        )
+        conformer_conv_part = ConformerConvolutionV1(cfg)
+        y = conformer_conv_part(x)
+        return y.shape
+
+    assert get_output_shape(10, 50, 250) == (10, 50, 250)
+    assert get_output_shape(10, 50, 250, activation=nn.functional.relu) == (10, 50, 250)  # different activation
+    assert get_output_shape(10, 50, 250, norm=LayerNormNC(250)) == (10, 50, 250)  # different norm
+    assert get_output_shape(1, 50, 100) == (1, 50, 100)  # test with batch size 1
+    assert get_output_shape(10, 1, 50) == (10, 1, 50)  # time dim 1
+    assert get_output_shape(10, 10, 20, dropout=0.0) == (10, 10, 20)  # dropout 0
+    assert get_output_shape(10, 10, 20, kernel_size=3) == (10, 10, 20)  # odd kernel size
+    assert get_output_shape(10, 10, 20, kernel_size=32) == (10, 10, 20)  # even kernel size
+
+
+def test_layer_norm_nc():
+    torch.manual_seed(42)
+
+    def get_output(x, norm):
+        out = norm(x)
+        return out
+
+    x = torch.randn(10, 50, 250)  # [N,T,C]; both norms must see the same data
+    torch_ln = get_output(x, nn.LayerNorm(250))
+    custom_ln = get_output(x.transpose(1, 2), LayerNormNC(250))
+    assert torch.allclose(torch_ln, custom_ln.transpose(1, 2))
+
+    x = torch.randn(10, 8, 23)  # [N,T,C]
+    torch_ln = get_output(x, nn.LayerNorm(23))
+    custom_ln = get_output(x.transpose(1, 2), LayerNormNC(23))
+    assert torch.allclose(torch_ln, custom_ln.transpose(1, 2))
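
Both new tests use plain asserts, so any standard test runner picks them up; for example (assuming pytest is installed, which is not listed in requirements.txt):

pytest tests/test_conformer.py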
