Commit 5ce9542

Merge branch 'main' into zeineldeen_att_decoder
2 parents 0ac8e69 + d2c8a24 commit 5ce9542

File tree

14 files changed: +949 and -12 lines changed


i6_models/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .conformer_v1 import *

Lines changed: 113 additions & 0 deletions

from __future__ import annotations

__all__ = ["ConformerBlockV1Config", "ConformerEncoderV1Config", "ConformerBlockV1", "ConformerEncoderV1"]

import torch
from torch import nn
from dataclasses import dataclass
from typing import Tuple

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
    ConformerConvolutionV1,
    ConformerConvolutionV1Config,
    ConformerMHSAV1,
    ConformerMHSAV1Config,
    ConformerPositionwiseFeedForwardV1,
    ConformerPositionwiseFeedForwardV1Config,
)


@dataclass
class ConformerBlockV1Config(ModelConfiguration):
    """
    Attributes:
        ff_cfg: Configuration for ConformerPositionwiseFeedForwardV1
        mhsa_cfg: Configuration for ConformerMHSAV1
        conv_cfg: Configuration for ConformerConvolutionV1
    """

    # nested configurations
    ff_cfg: ConformerPositionwiseFeedForwardV1Config
    mhsa_cfg: ConformerMHSAV1Config
    conv_cfg: ConformerConvolutionV1Config


class ConformerBlockV1(nn.Module):
    """
    Conformer block module
    """

    def __init__(self, cfg: ConformerBlockV1Config):
        """
        :param cfg: conformer block configuration with subunits for the different conformer parts
        """
        super().__init__()
        self.ff1 = ConformerPositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
        self.mhsa = ConformerMHSAV1(cfg=cfg.mhsa_cfg)
        self.conv = ConformerConvolutionV1(model_cfg=cfg.conv_cfg)
        self.ff2 = ConformerPositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
        self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
        """
        :param x: input tensor of shape [B, T, F]
        :param sequence_mask: bool mask of shape [B, T], True marks positions inside the sequence, False marks padding
        :return: torch.Tensor of shape [B, T, F]
        """
        x = 0.5 * self.ff1(x) + x  # [B, T, F]
        x = self.mhsa(x, sequence_mask) + x  # [B, T, F]
        x = self.conv(x) + x  # [B, T, F]
        x = 0.5 * self.ff2(x) + x  # [B, T, F]
        x = self.final_layer_norm(x)  # [B, T, F]
        return x


@dataclass
class ConformerEncoderV1Config(ModelConfiguration):
    """
    Attributes:
        num_layers: Number of conformer layers in the conformer encoder
        frontend: A pair of ConformerFrontend and corresponding config
        block_cfg: Configuration for ConformerBlockV1
    """

    num_layers: int

    # nested configurations
    frontend: ModuleFactoryV1
    block_cfg: ConformerBlockV1Config


class ConformerEncoderV1(nn.Module):
    """
    Implementation of the convolution-augmented Transformer (short Conformer), as in the original publication.
    The model consists of a frontend and a stack of N conformer blocks.
    C.f. https://arxiv.org/pdf/2005.08100.pdf
    """

    def __init__(self, cfg: ConformerEncoderV1Config):
        """
        :param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
        """
        super().__init__()

        self.frontend = cfg.frontend()
        self.module_list = torch.nn.ModuleList([ConformerBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])

    def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param data_tensor: input tensor of shape [B, T', F]
        :param sequence_mask: bool mask of shape [B, T'], True marks positions inside the sequence, False marks padding
        :return: (output, out_seq_mask)
            where output is torch.Tensor of shape [B, T, F'],
            out_seq_mask is a torch.Tensor of shape [B, T]

        F: input feature dim, F': internal and output feature dim
        T': data time dim, T: down-sampled time dim (internal time dim)
        """
        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T, F']
        for module in self.module_list:
            x = module(x, sequence_mask)  # [B, T, F']

        return x, sequence_mask
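
As a usage note (not part of this commit): the encoder expects a boolean sequence mask with True at valid frames. Below is a minimal sketch, assuming an already constructed ConformerEncoderV1 (config construction omitted) and an illustrative lengths_to_mask helper that is not part of the repository:

```python
import torch

# Illustrative helper (assumption, not part of this commit): build a [B, T'] bool mask
# from per-sequence lengths, True for frames inside the sequence.
def lengths_to_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]

features = torch.randn(3, 100, 80)    # [B, T', F], e.g. log-mel features
lengths = torch.tensor([100, 73, 58]) # valid frames per sequence
sequence_mask = lengths_to_mask(lengths, features.size(1))  # [B, T'] bool

# Assuming `encoder = ConformerEncoderV1(cfg)` with a frontend that downsamples T' -> T:
# out, out_mask = encoder(features, sequence_mask)  # out: [B, T, F'], out_mask: [B, T]
```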

i6_models/parts/conformer/mhsa.py
Lines changed: 7 additions & 5 deletions

@@ -2,10 +2,10 @@
 
 __all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config"]
 from dataclasses import dataclass
-from typing import Optional
 import torch
 
 from i6_models.config import ModelConfiguration
+from i6_models.util import compat
 
 
 @dataclass
@@ -43,17 +43,19 @@ def __init__(self, cfg: ConformerMHSAV1Config):
         )
         self.dropout = cfg.dropout
 
-    def forward(self, input_tensor: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         """
         Apply layer norm and multi-head self attention and dropout
-        :param Optional[torch.Tensor] key_padding_mask: could be a binary or float mask of shape (B, T)
+
+        :param input_tensor: Input to the self attention of shape (B, T, F)
+        :param sequence_mask: bool mask of shape (B, T), True signals within sequence, False outside, will be inverted to match the torch.nn.MultiheadAttention module
             which will be applied/added to dot product, used to mask padded key positions out
         """
-
+        inv_sequence_mask = compat.logical_not(sequence_mask)
         output_tensor = self.layernorm(input_tensor)  # [B,T,F]
 
         output_tensor, _ = self.mhsa(
-            output_tensor, output_tensor, output_tensor, key_padding_mask=key_padding_mask, need_weights=False
+            output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False
         )  # [B,T,F]
         output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)  # [B,T,F]
 
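
Side note on the inversion (not part of this commit): torch.nn.MultiheadAttention treats key_padding_mask=True as a padded key to be ignored, which is the opposite of the repository's sequence_mask convention (True = inside the sequence), hence the logical_not above. A self-contained sketch, using torch.logical_not in place of the repository's compat helper:

```python
import torch

sequence_mask = torch.tensor([[True, True, True, False],
                              [True, True, False, False]])  # True = valid frame

# torch.nn.MultiheadAttention expects True at *padded* key positions,
# so the sequence mask has to be inverted before the call.
key_padding_mask = torch.logical_not(sequence_mask)

mhsa = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.randn(2, 4, 8)  # [B, T, F]
out, _ = mhsa(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)
```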

i6_models/parts/frontend/README.md
Lines changed: 12 additions & 0 deletions

# Different front-ends for acoustic encoders

### Contributing

If you want to add your own front-end:

- Normally two classes are required: a config class and a model class
- The `Config` class inherits from `ModelConfiguration`
- The `Model` class inherits from `nn.Module` from `torch`
- The model's forward signature is `forward(tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]`
- `sequence_mask` is a boolean tensor where `True` means the position is inside the sequence and `False` means it is masked
- Please add tests
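
A minimal sketch of a front-end that follows this contract; the class names and config fields below are illustrative assumptions, not part of this commit:

```python
from dataclasses import dataclass
from typing import Tuple

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class ExampleFrontendV1Config(ModelConfiguration):
    """Illustrative config: the feature dims are assumptions for this sketch."""

    in_features: int
    out_features: int


class ExampleFrontendV1(nn.Module):
    def __init__(self, cfg: ExampleFrontendV1Config):
        super().__init__()
        self.linear = nn.Linear(cfg.in_features, cfg.out_features)

    def forward(self, tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # [B, T, F] -> [B, T, F']; this sketch does not downsample, so the mask passes through unchanged
        return self.linear(tensor), sequence_mask
```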

i6_models/parts/frontend/__init__.py

Whitespace-only changes.

i6_models/parts/frontend/common.py
Lines changed: 48 additions & 0 deletions

from typing import Tuple, Union

import torch
from torch import nn
from torch.nn import functional


def get_same_padding(input_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    """
    get padding in order to not reduce the time dimension

    :param input_size: kernel/filter size, either a single int or a tuple with one entry per dimension
    :return: padding such that the time dimension is not reduced (for stride 1)
    """
    if isinstance(input_size, int):
        return (input_size - 1) // 2
    elif isinstance(input_size, tuple):
        return tuple((s - 1) // 2 for s in input_size)
    else:
        raise TypeError(f"unexpected size type {type(input_size)}")


def mask_pool(seq_mask: torch.Tensor, *, kernel_size: int, stride: int, padding: int) -> torch.Tensor:
    """
    apply strides to the masking

    :param seq_mask: [B,T]
    :param kernel_size: kernel size of the corresponding convolution/pooling
    :param stride: stride of the corresponding convolution/pooling
    :param padding: padding of the corresponding convolution/pooling
    :return: [B,T'] using maxpool
    """
    if stride == 1 and 2 * padding == kernel_size - 1:
        return seq_mask

    seq_mask = seq_mask.float()
    seq_mask = torch.unsqueeze(seq_mask, 1)  # [B,1,T]
    seq_mask = nn.functional.max_pool1d(seq_mask, kernel_size, stride, padding)  # [B,1,T']
    seq_mask = torch.squeeze(seq_mask, 1)  # [B,T']
    seq_mask = seq_mask.bool()
    return seq_mask


def calculate_output_dim(in_dim: int, *, filter_size: int, stride: int, padding: int) -> int:
    def ceildiv(a: int, b: int):
        return -(-a // b)

    return ceildiv(in_dim + 2 * padding - (filter_size - 1) * 1, stride)
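
For illustration (not part of this commit), a quick consistency check of these helpers against an actual nn.Conv1d over the time axis; the sizes below are arbitrary:

```python
import torch
from torch import nn

from i6_models.parts.frontend.common import get_same_padding, mask_pool, calculate_output_dim

# Arbitrary example values.
batch, time, kernel_size, stride = 2, 50, 5, 3
padding = get_same_padding(kernel_size)  # (5 - 1) // 2 = 2

conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=stride, padding=padding)
out_time = conv(torch.randn(batch, 1, time)).shape[-1]

# calculate_output_dim predicts the convolution's output length ...
assert out_time == calculate_output_dim(time, filter_size=kernel_size, stride=stride, padding=padding)

# ... and mask_pool downsamples the sequence mask to the same length.
seq_mask = torch.ones(batch, time, dtype=torch.bool)
assert mask_pool(seq_mask, kernel_size=kernel_size, stride=stride, padding=padding).shape[-1] == out_time
```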
