
Commit 83ff39e

Authored by DanEnergetics, Daniel Mann, michelwi, and Atticus1806
Feed forward layer, frontend and encoder (#53)
Co-authored-by: Daniel Mann <[email protected]>
Co-authored-by: michelwi <[email protected]>
Co-authored-by: Benedikt Hilmes <[email protected]>
1 parent 1264482 commit 83ff39e

File tree

6 files changed: +296 −0 lines changed
i6_models/assemblies/ffnn/__init__.py

Lines changed: 1 addition & 0 deletions

from .ffnn_v1 import *
i6_models/assemblies/ffnn/ffnn_v1.py

Lines changed: 42 additions & 0 deletions

__all__ = ["FeedForwardEncoderV1Config", "FeedForwardEncoderV1"]

from typing import Tuple
from dataclasses import dataclass

import torch
from torch import nn

from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config
from i6_models.config import ModelConfiguration, ModuleFactoryV1


@dataclass
class FeedForwardEncoderV1Config(ModelConfiguration):
    """
    Attributes:
        num_layers: number of feed-forward layers
        frontend: module factory for the frontend
        layer_cfg: configuration object for each feed-forward layer
    """

    num_layers: int
    frontend: ModuleFactoryV1
    layer_cfg: FeedForwardLayerV1Config


class FeedForwardEncoderV1(nn.Module):
    """
    Simple feed-forward encoder.
    Subsampling can be achieved by setting stride > 1 in the frontend config.
    """

    def __init__(self, cfg: FeedForwardEncoderV1Config):
        super().__init__()
        self.frontend = cfg.frontend()
        self.module_list = nn.ModuleList([FeedForwardLayerV1(cfg.layer_cfg) for _ in range(cfg.num_layers)])

    def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T', F'], frontend may subsample T
        for module in self.module_list:
            x, sequence_mask = module(x, sequence_mask)  # [B, T', F']

        return x, sequence_mask
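
For orientation, a minimal usage sketch of the encoder; the dimensions and hyperparameters here are hypothetical, chosen only for illustration, and it assumes the WindowConvolutionFrontendV1 added further down in this commit:

import torch
from torch.nn import functional as F

from i6_models.assemblies.ffnn import FeedForwardEncoderV1, FeedForwardEncoderV1Config
from i6_models.config import ModuleFactoryV1
from i6_models.parts.ffnn import FeedForwardLayerV1Config
from i6_models.parts.frontend.window_convolution import (
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config,
)

# hypothetical setup: 80 input features, 512 model dim, stride 2 for 2x time subsampling
frontend = ModuleFactoryV1(
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config(
        input_dim=80, output_dim=512, kernel_size=5, stride=2, dropout=0.1, activation=F.relu
    ),
)
cfg = FeedForwardEncoderV1Config(
    num_layers=2,
    frontend=frontend,
    layer_cfg=FeedForwardLayerV1Config(input_dim=512, output_dim=512, dropout=0.1, activation=F.relu),
)
encoder = FeedForwardEncoderV1(cfg)

features = torch.randn(4, 100, 80)  # [B, T, F]
sequence_mask = torch.ones(4, 100, dtype=torch.bool)  # [B, T], all frames valid
out, out_mask = encoder(features, sequence_mask)  # out: [4, 50, 512], since stride=2 halves T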

i6_models/parts/ffnn.py

Lines changed: 59 additions & 0 deletions

__all__ = ["FeedForwardLayerV1Config", "FeedForwardLayerV1"]

from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class FeedForwardLayerV1Config(ModelConfiguration):
    """
    Attributes:
        input_dim: input feature dimension
        output_dim: output feature dimension
        dropout: dropout probability
        activation: activation function applied after linear computation
    """

    input_dim: int
    output_dim: int
    dropout: float
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]

    def __post_init__(self):
        super().__post_init__()
        assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"


class FeedForwardLayerV1(nn.Module):
    """
    Simple feed-forward layer module consisting of:
    - linear
    - activation
    - dropout
    """

    def __init__(self, cfg: FeedForwardLayerV1Config):
        super().__init__()
        self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True)
        self.activation = cfg.activation
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(
        self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param tensor: shape [B,T,F], F=input_dim
        :param sequence_mask: shape [B,T]
        :return: tensor of shape [B,T,F'], F'=output_dim, and the unchanged sequence mask
        """
        tensor = self.linear_ff(tensor)  # [B,T,F']
        tensor = self.activation(tensor)  # [B,T,F']
        tensor = self.dropout(tensor)  # [B,T,F']
        return tensor, sequence_mask
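
The layer leaves the time dimension and the mask untouched; only the feature dimension changes. A minimal sketch, with hypothetical dimensions:

import torch
from torch.nn import functional as F

from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config

layer = FeedForwardLayerV1(
    FeedForwardLayerV1Config(input_dim=80, output_dim=256, dropout=0.1, activation=F.relu)
)
x = torch.randn(4, 100, 80)  # [B, T, F]
mask = torch.ones(4, 100, dtype=torch.bool)  # [B, T]
y, y_mask = layer(x, mask)  # y: [4, 100, 256]; y_mask is the input mask, passed through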
i6_models/parts/frontend/window_convolution.py

Lines changed: 90 additions & 0 deletions

__all__ = [
    "WindowConvolutionFrontendV1Config",
    "WindowConvolutionFrontendV1",
]

from dataclasses import dataclass
from typing import Callable, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration

from .common import mask_pool, apply_same_padding


@dataclass
class WindowConvolutionFrontendV1Config(ModelConfiguration):
    """
    Attributes:
        input_dim: number of input features to module
        output_dim: output dimension
        dropout: dropout after linear layer
        kernel_size: number of feature frames to convolve
        stride: skip (stride - 1) feature frames; stride > 1 implies subsampling
        activation: activation function applied after linear computation
    """

    input_dim: int
    output_dim: int
    dropout: float
    kernel_size: int
    stride: int
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]

    def __post_init__(self):
        super().__post_init__()
        assert self.stride >= 1, "Choose an integer >= 1 for stride"
        assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"


class WindowConvolutionFrontendV1(nn.Module):
    """
    Simple feed-forward front-end that computes over a window
    of input features. Choosing a stride > 1 allows for subsampling
    of the features.
    """

    def __init__(self, cfg: WindowConvolutionFrontendV1Config):
        """
        :param cfg: model configuration for this module
        """
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=cfg.input_dim,
            out_channels=cfg.output_dim,
            kernel_size=cfg.kernel_size,
            stride=cfg.stride,
            padding=0,
            bias=True,
        )
        self.activation = cfg.activation
        self.pad = lambda x: apply_same_padding(x, cfg.kernel_size)
        self.dropout = torch.nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        T might be reduced to T' depending on the stride.

        :param x: input tensor of shape [B,T,F]
        :param sequence_mask: sequence mask for the tensor, shape [B,T]
        :return: torch.Tensor of shape [B,T',F'] and the subsampled sequence mask of shape [B,T']
        """
        # torch's 1d convolution runs over the last dim, but we want to convolve over time
        x = x.transpose(1, 2)  # [B, F, T]
        x = self.pad(x)
        x = self.conv(x).transpose(1, 2)  # [B, T', F']

        # subsample the mask according to the stride value
        sequence_mask = mask_pool(
            sequence_mask,
            kernel_size=1,
            stride=self.conv.stride[0],
            padding=0,  # padding is done manually above
        )
        x = self.activation(x)
        x = self.dropout(x)

        return x, sequence_mask
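
Because the padding is "same", the subsampled length is T' = (T - 1) // stride + 1 (equivalently ceil(T / stride)), independent of the kernel size; the tests below assert exactly this. A small sanity-check sketch with hypothetical numbers:

import torch
from torch.nn import functional as F

from i6_models.parts.frontend.window_convolution import (
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config,
)

frontend = WindowConvolutionFrontendV1(
    WindowConvolutionFrontendV1Config(
        input_dim=80, output_dim=512, kernel_size=5, stride=3, dropout=0.1, activation=F.relu
    )
)
x = torch.randn(1, 100, 80)  # [B, T, F] with T = 100
mask = torch.ones(1, 100, dtype=torch.bool)
out, out_mask = frontend(x, mask)
# T' = (100 - 1) // 3 + 1 = 34 frames after subsampling
assert out.shape == (1, 34, 512) and out_mask.shape == (1, 34)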

tests/test_ffnn.py

Lines changed: 62 additions & 0 deletions

from itertools import product

import torch
from torch.nn import functional as F

from i6_models.assemblies.ffnn import (
    FeedForwardEncoderV1,
    FeedForwardEncoderV1Config,
)
from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1
from i6_models.config import ModuleFactoryV1
from i6_models.parts.ffnn import FeedForwardLayerV1Config


def test_output_shape():
    input_dim = 80
    output_dim = 2048
    dropout = 0.1
    max_seq_lens = 100

    for window_size, stride in product(range(1, 22), range(1, 5)):
        frontend = ModuleFactoryV1(
            WindowConvolutionFrontendV1,
            WindowConvolutionFrontendV1Config(
                input_dim=input_dim,
                output_dim=output_dim,
                kernel_size=window_size,
                dropout=dropout,
                stride=stride,
                activation=F.relu,
            ),
        )

        layer_cfg = FeedForwardLayerV1Config(
            input_dim=output_dim,
            output_dim=output_dim,
            dropout=dropout,
            activation=F.relu,
        )

        encoder_cfg = FeedForwardEncoderV1Config(num_layers=6, layer_cfg=layer_cfg, frontend=frontend)
        encoder = FeedForwardEncoderV1(encoder_cfg)

        # batch of 100 sequences with lengths 1..100; mask[b, t] is True for valid frames
        feat_len = torch.arange(start=1, end=max_seq_lens + 1)
        mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
        features = torch.empty((max_seq_lens, max_seq_lens, input_dim))

        out, out_mask = encoder(features, mask)

        # same padding: T' = (T - 1) // stride + 1
        expected_out_len = (feat_len - 1) // stride + 1
        expected_shape = (max_seq_lens, expected_out_len[-1], output_dim)
        assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
        for i in range(expected_out_len[-1] - 1):
            # the mask must flip from True to False exactly at the expected output length
            assert (
                out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
            ), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"

tests/test_window_frontend.py

Lines changed: 42 additions & 0 deletions

from itertools import product

import torch
from torch.nn import functional as F

from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1


def test_output_shape():
    in_features = 80
    out_features = 2048
    dropout = 0.1
    max_seq_lens = 100

    for window_size, stride in product(range(1, 22), range(1, 5)):
        frontend = WindowConvolutionFrontendV1(
            WindowConvolutionFrontendV1Config(
                input_dim=in_features,
                output_dim=out_features,
                kernel_size=window_size,
                dropout=dropout,
                stride=stride,
                activation=F.relu,
            )
        )

        # batch of 100 sequences with lengths 1..100; mask[b, t] is True for valid frames
        feat_len = torch.arange(start=1, end=max_seq_lens + 1)
        mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
        features = torch.empty((max_seq_lens, max_seq_lens, in_features))

        out, out_mask = frontend(features, mask)

        # same padding: T' = (T - 1) // stride + 1
        expected_out_len = (feat_len - 1) // stride + 1
        expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
        assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
        for i in range(expected_out_len[-1] - 1):
            # the mask must flip from True to False exactly at the expected output length
            assert (
                out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
            ), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
