Skip to content

Commit 1264482

Browse files
Add helper function applying correct padding for even kernel sizes (#54)
1 parent 5e4d013 commit 1264482

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

i6_models/parts/frontend/common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,30 @@ def get_same_padding(input_size: Union[int, Tuple[int, ...]]) -> Union[int, Tupl
2020
raise TypeError(f"unexpected size type {type(input_size)}")
2121

2222

23+
def apply_same_padding(x: torch.Tensor, kernel_size: Union[int, Tuple[int, ...]], **kwargs) -> torch.Tensor:
24+
"""
25+
Pad tensor almost symmetrically in one or more dimensions in order to not reduce time dimension
26+
when applying convolution with the given kernel. As opposed to the standard padding parameter
27+
this also handles even kernel sizes.
28+
29+
:param x:
30+
:param kernel_size: kernel size of the convolution for which the tensor is padded
31+
:param kwargs: keyword args passed to functional.pad
32+
:return: padded tensor
33+
"""
34+
if isinstance(kernel_size, int):
35+
h = (kernel_size - 1) // 2
36+
return functional.pad(x, (h, kernel_size - 1 - h), **kwargs)
37+
elif isinstance(kernel_size, tuple):
38+
paddings = ()
39+
for k in reversed(kernel_size): # padding function starts with last dim
40+
h = (k - 1) // 2
41+
paddings += (h, k - 1 - h)
42+
return functional.pad(x, paddings, **kwargs)
43+
else:
44+
raise TypeError(f"Unexpected size type {type(kernel_size)}")
45+
46+
2347
def mask_pool(seq_mask: torch.Tensor, *, kernel_size: int, stride: int, padding: int) -> torch.Tensor:
2448
"""
2549
apply strides to the masking

tests/test_padding.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from itertools import product
2+
3+
import torch
4+
from torch import nn
5+
6+
from i6_models.parts.frontend.common import apply_same_padding, get_same_padding
7+
8+
9+
def test_output_shape():
10+
# test for even and odd dim
11+
last_dim = 101
12+
pre_last_dim = 100
13+
14+
iff = lambda x, y: x and y or not x and not y # x <=> y
15+
strided_dim = lambda d, s: (d - 1) // s + 1 # expected out dimension for strided conv
16+
17+
# `get_same_padding` seems to work for some stride > 1
18+
for kernel in product(range(1, 21), repeat=2):
19+
conv = nn.Conv2d(1, 1, kernel_size=kernel, stride=(1, 1), padding=get_same_padding(kernel))
20+
21+
x = torch.randn(1, 1, pre_last_dim, last_dim)
22+
23+
out = conv(x)
24+
25+
# we expect `get_same_padding` to only cover odd kernel sizes
26+
assert all(
27+
iff(out_dim == in_dim, k % 2 == 1) for in_dim, out_dim, k in zip(x.shape[2:], out.shape[2:], kernel)
28+
), f"Failed for {x.shape=}, {out.shape=}, {kernel=} and stride=1"
29+
30+
for kernel, stride in product(product(range(1, 21), repeat=2), range(1, 7)):
31+
conv = nn.Conv2d(1, 1, kernel_size=kernel, stride=(1, stride))
32+
33+
x = torch.randn(1, 1, pre_last_dim, last_dim)
34+
x_padded = apply_same_padding(x, kernel)
35+
36+
out = conv(x_padded)
37+
38+
# correct out dimensions for all possible kernel sizes and strides
39+
assert all(
40+
out_dim == strided_dim(in_dim, s)
41+
for in_dim, out_dim, k, s in zip(x.shape[2:], out.shape[2:], kernel, (1, stride))
42+
), f"Failed for {x.shape=}, {out.shape=}, {kernel=} and {stride=}"

0 commit comments

Comments
 (0)