
Commit 83af04d

Add MHSA module (#7)

Authored by kuacakuaca
Co-authored-by: pzheng <[email protected]>
Co-authored-by: Eugen Beck <[email protected]>
Co-authored-by: michelwi <[email protected]>
1 parent 147c00c commit 83af04d

File tree

2 files changed: +90 −15 lines


i6_models/parts/conformer/mhsa.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Callable
import torch

from i6_models.config import ModelConfiguration


@dataclass
class ConformerMHSAV1Config(ModelConfiguration):
    input_dim: int
    """input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads`"""
    num_att_heads: int
    """number of attention heads"""
    att_weights_dropout: float
    """attention weights dropout"""
    dropout: float
    """multi-headed self attention output dropout"""

    def __post_init__(self) -> None:
        super().__post_init__()
        assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads"


class ConformerMHSAV1(torch.nn.Module):
    """
    Conformer multi-headed self-attention module
    """

    def __init__(self, cfg: ConformerMHSAV1Config):

        super().__init__()

        self.layernorm = torch.nn.LayerNorm(cfg.input_dim)
        self.mhsa = torch.nn.MultiheadAttention(
            cfg.input_dim, cfg.num_att_heads, dropout=cfg.att_weights_dropout, batch_first=True
        )
        self.dropout = cfg.dropout

    def forward(self, input_tensor: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply layer norm and multi-head self attention and dropout

        :param Optional[torch.Tensor] key_padding_mask: could be a binary or float mask of shape (B, T)
            which will be applied/added to dot product, used to mask padded key positions out
        """

        output_tensor = self.layernorm(input_tensor)  # [B,T,F]

        output_tensor, _ = self.mhsa(
            output_tensor, output_tensor, output_tensor, key_padding_mask=key_padding_mask, need_weights=False
        )  # [B,T,F]
        output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)  # [B,T,F]

        return output_tensor
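Not part of the commit, but a minimal usage sketch of the new module: it builds a boolean key padding mask from per-sequence lengths, where True marks padded key positions to be ignored. The batch size, sequence lengths, and feature dimension below are made up purely for illustration.

import torch

from i6_models.parts.conformer.mhsa import ConformerMHSAV1, ConformerMHSAV1Config

# hypothetical sizes, chosen only for illustration
batch, max_time, feat = 3, 10, 20
seq_lens = torch.tensor([10, 7, 4])  # assumed true length of each sequence in the batch

cfg = ConformerMHSAV1Config(input_dim=feat, num_att_heads=4, att_weights_dropout=0.1, dropout=0.1)
mhsa = ConformerMHSAV1(cfg)

# boolean mask of shape [B, T]; True marks padded positions that attention should ignore
key_padding_mask = torch.arange(max_time)[None, :] >= seq_lens[:, None]

features = torch.randn(batch, max_time, feat)  # [B, T, F]
out = mhsa(features, key_padding_mask=key_padding_mask)  # [B, T, F]
assert out.shape == features.shape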

tests/test_conformer.py

Lines changed: 36 additions & 15 deletions
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from itertools import product
 
 import torch
@@ -8,24 +9,10 @@
     ConformerPositionwiseFeedForwardV1,
     ConformerPositionwiseFeedForwardV1Config,
 )
+from i6_models.parts.conformer.mhsa import ConformerMHSAV1Config, ConformerMHSAV1
 from i6_models.parts.conformer.norm import LayerNormNC
 
 
-def test_ConformerPositionwiseFeedForwardV1():
-    def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
-        x = torch.randn(input_shape)
-        cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
-        conf_ffn_part = ConformerPositionwiseFeedForwardV1(cfg)
-        y = conf_ffn_part(x)
-        return y.shape
-
-    for input_dim, hidden_dim, dropout, activation in product(
-        [10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu]
-    ):
-        input_shape = (10, 100, input_dim)
-        assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape
-
-
 def test_conformer_convolution_output_shape():
     def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0.1, activation=nn.functional.silu):
         x = torch.randn(batch, time, features)
@@ -48,6 +35,40 @@ def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0
     assert get_output_shape(10, 10, 20, kernel_size=32) == (10, 10, 20)  # even kernel size
 
 
+def test_ConformerPositionwiseFeedForwardV1():
+    def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
+        x = torch.randn(input_shape)
+        cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
+        conf_ffn_part = ConformerPositionwiseFeedForwardV1(cfg)
+        y = conf_ffn_part(x)
+        return y.shape
+
+    for input_dim, hidden_dim, dropout, activation in product(
+        [10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu]
+    ):
+        input_shape = (10, 100, input_dim)
+        assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape
+
+
+def test_ConformerMHSAV1():
+    def get_output_shape(input_shape, cfg, **kwargs):
+
+        input = torch.randn(input_shape)
+        output = ConformerMHSAV1(cfg)(input, **kwargs)
+
+        return list(output.shape)
+
+    # without key padding mask
+    input_shape = [3, 10, 20]  # B,T,F
+    cfg = ConformerMHSAV1Config(20, 4, 0.1, 0.1)
+    assert get_output_shape(input_shape, cfg) == [3, 10, 20]
+
+    # with key padding mask
+    input_shape = [4, 15, 32]  # B,T,F
+    cfg = ConformerMHSAV1Config(32, 8, 0.2, 0.3)
+    assert get_output_shape(input_shape, cfg, key_padding_mask=torch.randint(0, 2, input_shape[:2]) > 0) == [4, 15, 32]
+
+
 def test_layer_norm_nc():
     torch.manual_seed(42)
 