Skip to content

Commit 0f20821

Browse files
Judyxujj and albertz authored
Implement ConformerFeedForwardV1 Part (#6)
Co-authored-by: Albert Zeyer <[email protected]>
1 parent afa01f6 commit 0f20821

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from __future__ import annotations
2+
from dataclasses import dataclass
3+
from typing import Callable
4+
5+
import torch
6+
from torch import nn
7+
8+
from i6_models.config import ModelConfiguration
9+
10+
11+
@dataclass
class ConformerPositionwiseFeedForwardV1Config(ModelConfiguration):
    """
    Configuration for :class:`ConformerPositionwiseFeedForwardV1`.

    Attributes:
        input_dim: input dimension
        hidden_dim: hidden dimension (normally set to 4*input_dim as suggested by the paper)
        dropout: dropout probability
        activation: activation function applied between the two linear layers
    """

    input_dim: int
    """input dimension"""
    hidden_dim: int
    """hidden dimension (normally set to 4*input_dim as suggested by the paper)"""
    dropout: float
    """dropout probability"""
    activation: Callable[[torch.Tensor], torch.Tensor] = nn.functional.silu
    """activation function"""
21+
22+
23+
class ConformerPositionwiseFeedForwardV1(nn.Module):
    """
    Conformer feedforward module

    Sequence of operations: layer norm -> linear up-projection -> activation
    -> dropout -> linear down-projection -> dropout. The feature dimension of
    the input is preserved.
    """

    def __init__(self, cfg: ConformerPositionwiseFeedForwardV1Config):
        """
        :param cfg: configuration holding dimensions, dropout probability and activation
        """
        super().__init__()

        # NOTE: submodules are created in this fixed order so that parameter
        # initialization (RNG consumption) and state-dict keys stay stable.
        self.layer_norm = nn.LayerNorm(cfg.input_dim)
        self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.hidden_dim, bias=True)
        self.activation = cfg.activation
        self.linear_out = nn.Linear(in_features=cfg.hidden_dim, out_features=cfg.input_dim, bias=True)
        self.dropout = cfg.dropout

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: shape [B,T,F], F=input_dim
        :return: shape [B,T,F], F=input_dim
        """
        hidden = self.linear_ff(self.layer_norm(tensor))  # [B,T,hidden_dim]
        hidden = self.activation(hidden)  # [B,T,hidden_dim]
        hidden = nn.functional.dropout(hidden, p=self.dropout, training=self.training)
        out = self.linear_out(hidden)  # back to [B,T,F]
        out = nn.functional.dropout(out, p=self.dropout, training=self.training)
        return out

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
typeguard
1+
typeguard
2+
torch

tests/test_conformer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from itertools import product
2+
3+
import torch
4+
from torch import nn
5+
6+
from i6_models.parts.conformer.feedforward import (
7+
ConformerPositionwiseFeedForwardV1,
8+
ConformerPositionwiseFeedForwardV1Config,
9+
)
10+
11+
12+
def test_ConformerPositionwiseFeedForwardV1():
    """Check that the feedforward part preserves the input shape for a grid of configurations."""

    def run_and_get_shape(input_shape, input_dim, hidden_dim, dropout, activation):
        # keep statement order (randn before module construction) for RNG reproducibility
        x = torch.randn(input_shape)
        cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
        part = ConformerPositionwiseFeedForwardV1(cfg)
        return part(x).shape

    for input_dim in [10, 20]:
        for hidden_dim in [100, 200]:
            for dropout in [0.1, 0.3]:
                for activation in [nn.functional.silu, nn.functional.relu]:
                    input_shape = (10, 100, input_dim)
                    assert run_and_get_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape

0 commit comments

Comments
 (0)