from itertools import product

import torch
from torch import nn

from i6_models.parts.conformer.convolution import ConformerConvolutionV1, ConformerConvolutionV1Config
from i6_models.parts.conformer.feedforward import (
    ConformerPositionwiseFeedForwardV1,
    ConformerPositionwiseFeedForwardV1Config,
)
from i6_models.parts.conformer.norm import LayerNormNC


def test_ConformerPositionwiseFeedForwardV1():
    def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
        x = torch.randn(input_shape)
        cfg = ConformerPositionwiseFeedForwardV1Config(
            input_dim=input_dim, hidden_dim=hidden_dim, dropout=dropout, activation=activation
        )
        y = ConformerPositionwiseFeedForwardV1(cfg)(x)
        return y.shape

    # the feed-forward part maps [B, T, F] -> [B, T, F], so every hyperparameter combination
    # must preserve the input shape (NOTE: the grid values below are an assumption; the source
    # hunk elides them)
    for input_dim, hidden_dim, dropout, activation in product(
        [10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu]
    ):
        input_shape = (10, 100, input_dim)
        assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape


def test_conformer_convolution_output_shape():
    def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0.1, activation=nn.functional.silu):
        x = torch.randn(batch, time, features)
        if norm is None:
            norm = nn.BatchNorm1d(features)
        cfg = ConformerConvolutionV1Config(
            channels=features, kernel_size=kernel_size, dropout=dropout, activation=activation, norm=norm
        )
        conformer_conv_part = ConformerConvolutionV1(cfg)
        y = conformer_conv_part(x)
        return y.shape

    assert get_output_shape(10, 50, 250) == (10, 50, 250)
    assert get_output_shape(10, 50, 250, activation=nn.functional.relu) == (10, 50, 250)  # different activation
    assert get_output_shape(10, 50, 250, norm=LayerNormNC(250)) == (10, 50, 250)  # different norm
    assert get_output_shape(1, 50, 100) == (1, 50, 100)  # test with batch size 1
    assert get_output_shape(10, 1, 50) == (10, 1, 50)  # time dim 1
    assert get_output_shape(10, 10, 20, dropout=0.0) == (10, 10, 20)  # dropout 0
    assert get_output_shape(10, 10, 20, kernel_size=3) == (10, 10, 20)  # odd kernel size
    assert get_output_shape(10, 10, 20, kernel_size=32) == (10, 10, 20)  # even kernel size
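

# A minimal determinism sketch, not part of the original tests. It assumes dropout is the
# only source of randomness in ConformerConvolutionV1, so two eval-mode forward passes over
# the same input (dropout disabled, BatchNorm using running stats) should agree exactly.
def test_conformer_convolution_eval_deterministic():
    cfg = ConformerConvolutionV1Config(
        channels=20, kernel_size=31, dropout=0.5, activation=nn.functional.silu, norm=nn.BatchNorm1d(20)
    )
    module = ConformerConvolutionV1(cfg).eval()
    x = torch.randn(3, 7, 20)
    assert torch.allclose(module(x), module(x))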


def test_layer_norm_nc():
    torch.manual_seed(42)

    def get_output(x, norm):
        return norm(x)

    # feed the *same* data to both norms: LayerNormNC normalizes the channel dim of a
    # [B, C, T] tensor, which must match nn.LayerNorm over the feature dim of the
    # transposed [B, T, C] tensor (the original drew two independent random tensors and
    # dropped the assert, so the comparison never ran)
    x = torch.randn(10, 250, 50)
    torch_ln = get_output(x.transpose(1, 2), nn.LayerNorm(250))
    custom_ln = get_output(x, LayerNormNC(250))
    assert torch.allclose(torch_ln, custom_ln.transpose(1, 2))

    # test with a different shape
    x = torch.randn(10, 23, 8)
    torch_ln = get_output(x.transpose(1, 2), nn.LayerNorm(23))
    custom_ln = get_output(x, LayerNormNC(23))
    assert torch.allclose(torch_ln, custom_ln.transpose(1, 2))
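

# A minimal extra sketch, not part of the original tests. It assumes LayerNormNC preserves
# the [B, C, T] shape and, with default affine parameters (weight=1, bias=0), yields zero
# mean over the channel dim at every (batch, time) position.
def test_layer_norm_nc_shape_and_stats():
    norm = LayerNormNC(64)
    x = torch.randn(4, 64, 9)
    y = norm(x)
    assert y.shape == x.shape
    assert torch.allclose(y.mean(dim=1), torch.zeros(4, 9), atol=1e-5)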