
Commit 2e76cd9

Atticus1806, Judyxujj, albertz, pzheng, and mmz33 authored
Conformer Assembly (#8)
Add Conformer Assembly

---------

Co-authored-by: Judyxujj <[email protected]>
Co-authored-by: Albert Zeyer <[email protected]>
Co-authored-by: pzheng <[email protected]>
Co-authored-by: Mohammad Zeineldeen <[email protected]>
Co-authored-by: Nick Rossenbach <[email protected]>
Co-authored-by: vieting <[email protected]>
Co-authored-by: michelwi <[email protected]>
1 parent 7e345a1 commit 2e76cd9


3 files changed: +115, -0 lines changed


i6_models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .conformer_v1 import *
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
from __future__ import annotations

__all__ = ["ConformerBlockV1Config", "ConformerEncoderV1Config", "ConformerBlockV1", "ConformerEncoderV1"]

import torch
from torch import nn
from dataclasses import dataclass
from typing import Tuple

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
    ConformerConvolutionV1,
    ConformerConvolutionV1Config,
    ConformerMHSAV1,
    ConformerMHSAV1Config,
    ConformerPositionwiseFeedForwardV1,
    ConformerPositionwiseFeedForwardV1Config,
)


@dataclass
class ConformerBlockV1Config(ModelConfiguration):
    """
    Attributes:
        ff_cfg: Configuration for ConformerPositionwiseFeedForwardV1
        mhsa_cfg: Configuration for ConformerMHSAV1
        conv_cfg: Configuration for ConformerConvolutionV1
    """

    # nested configurations
    ff_cfg: ConformerPositionwiseFeedForwardV1Config
    mhsa_cfg: ConformerMHSAV1Config
    conv_cfg: ConformerConvolutionV1Config


class ConformerBlockV1(nn.Module):
    """
    Conformer block module
    """

    def __init__(self, cfg: ConformerBlockV1Config):
        """
        :param cfg: conformer block configuration with subunits for the different conformer parts
        """
        super().__init__()
        self.ff1 = ConformerPositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
        self.mhsa = ConformerMHSAV1(cfg=cfg.mhsa_cfg)
        self.conv = ConformerConvolutionV1(model_cfg=cfg.conv_cfg)
        self.ff2 = ConformerPositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
        self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
        """
        :param x: input tensor of shape [B, T, F]
        :param sequence_mask: mask tensor where 0 defines positions within the sequence and 1 outside, shape: [B, T]
        :return: torch.Tensor of shape [B, T, F]
        """
        x = 0.5 * self.ff1(x) + x  # [B, T, F]
        x = self.mhsa(x, sequence_mask) + x  # [B, T, F]
        x = self.conv(x) + x  # [B, T, F]
        x = 0.5 * self.ff2(x) + x  # [B, T, F]
        x = self.final_layer_norm(x)  # [B, T, F]
        return x


@dataclass
class ConformerEncoderV1Config(ModelConfiguration):
    """
    Attributes:
        num_layers: Number of conformer layers in the conformer encoder
        frontend: A pair of ConformerFrontend and corresponding config
        block_cfg: Configuration for ConformerBlockV1
    """

    num_layers: int

    # nested configurations
    frontend: ModuleFactoryV1
    block_cfg: ConformerBlockV1Config


class ConformerEncoderV1(nn.Module):
    """
    Implementation of the convolution-augmented Transformer (short Conformer), as in the original publication.
    The model consists of a frontend and a stack of N conformer blocks.
    C.f. https://arxiv.org/pdf/2005.08100.pdf
    """

    def __init__(self, cfg: ConformerEncoderV1Config):
        """
        :param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
        """
        super().__init__()

        self.frontend = cfg.frontend()
        self.module_list = torch.nn.ModuleList([ConformerBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])

    def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param data_tensor: input tensor of shape [B, T', F]
        :param sequence_mask: mask tensor where 0 defines positions within the sequence and 1 outside, shape: [B, T']
        :return: (output, out_seq_mask)
            where output is torch.Tensor of shape [B, T, F'],
            out_seq_mask is a torch.Tensor of shape [B, T]

        F: input feature dim, F': internal and output feature dim
        T': data time dim, T: down-sampled time dim (internal time dim)
        """
        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T, F']
        for module in self.module_list:
            x = module(x, sequence_mask)  # [B, T, F']

        return x, sequence_mask
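
For reference, a minimal usage sketch of the new assembly (not part of this commit). It assumes the nested part configs ff_cfg, mhsa_cfg, conv_cfg and a frontend ModuleFactoryV1 instance have been built elsewhere from i6_models.parts.conformer and i6_models.config; their fields are not shown in this diff. The make_sequence_mask helper is hypothetical and only illustrates the mask convention stated in the docstrings (0/False inside the sequence, 1/True for padding).

import torch


def make_sequence_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Hypothetical helper: build the mask per the docstring convention above,
    # i.e. False/0 inside the sequence and True/1 for padded positions.
    positions = torch.arange(max_len, device=lengths.device)[None, :]  # [1, T']
    return positions >= lengths[:, None]  # [B, T'], True outside the sequence


# ff_cfg, mhsa_cfg, conv_cfg and frontend are assumed to be constructed elsewhere
# (ConformerPositionwiseFeedForwardV1Config, ConformerMHSAV1Config,
# ConformerConvolutionV1Config, ModuleFactoryV1); they are not defined in this diff.
block_cfg = ConformerBlockV1Config(ff_cfg=ff_cfg, mhsa_cfg=mhsa_cfg, conv_cfg=conv_cfg)
encoder_cfg = ConformerEncoderV1Config(num_layers=12, frontend=frontend, block_cfg=block_cfg)
encoder = ConformerEncoderV1(cfg=encoder_cfg)

features = torch.randn(4, 500, 80)  # [B, T', F], e.g. log-mel features
lengths = torch.tensor([500, 480, 350, 200])  # valid length of each sequence
mask = make_sequence_mask(lengths, features.shape[1])  # [B, T']

out, out_mask = encoder(features, mask)  # [B, T, F'] and [B, T] after frontend down-sampling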
