1+ from __future__ import annotations
12from itertools import product
23
34import torch
89 ConformerPositionwiseFeedForwardV1 ,
910 ConformerPositionwiseFeedForwardV1Config ,
1011)
12+ from i6_models .parts .conformer .mhsa import ConformerMHSAV1Config , ConformerMHSAV1
1113from i6_models .parts .conformer .norm import LayerNormNC
1214
1315
14- def test_ConformerPositionwiseFeedForwardV1 ():
15- def get_output_shape (input_shape , input_dim , hidden_dim , dropout , activation ):
16- x = torch .randn (input_shape )
17- cfg = ConformerPositionwiseFeedForwardV1Config (input_dim , hidden_dim , dropout , activation )
18- conf_ffn_part = ConformerPositionwiseFeedForwardV1 (cfg )
19- y = conf_ffn_part (x )
20- return y .shape
21-
22- for input_dim , hidden_dim , dropout , activation in product (
23- [10 , 20 ], [100 , 200 ], [0.1 , 0.3 ], [nn .functional .silu , nn .functional .relu ]
24- ):
25- input_shape = (10 , 100 , input_dim )
26- assert get_output_shape (input_shape , input_dim , hidden_dim , dropout , activation ) == input_shape
27-
28-
2916def test_conformer_convolution_output_shape ():
3017 def get_output_shape (batch , time , features , norm = None , kernel_size = 31 , dropout = 0.1 , activation = nn .functional .silu ):
3118 x = torch .randn (batch , time , features )
@@ -48,6 +35,40 @@ def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0
4835 assert get_output_shape (10 , 10 , 20 , kernel_size = 32 ) == (10 , 10 , 20 ) # even kernel size
4936
5037
def test_ConformerPositionwiseFeedForwardV1():
    """Check that the conformer feed-forward block preserves the input shape.

    Sweeps a grid of model dims, hidden dims, dropout rates and activations
    and asserts the module maps (B, T, input_dim) -> (B, T, input_dim).
    """

    def run_ffn(shape, model_dim, ff_dim, drop, act):
        # Build the module from its config and push one random batch through.
        cfg = ConformerPositionwiseFeedForwardV1Config(model_dim, ff_dim, drop, act)
        module = ConformerPositionwiseFeedForwardV1(cfg)
        return module(torch.randn(shape)).shape

    grid = product([10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu])
    for model_dim, ff_dim, drop, act in grid:
        shape = (10, 100, model_dim)
        # The FFN block must be shape-preserving regardless of hyperparameters.
        assert run_ffn(shape, model_dim, ff_dim, drop, act) == shape
51+
52+
def test_ConformerMHSAV1():
    """Check that the conformer MHSA block keeps the (B, T, F) input shape,
    both with and without a key padding mask."""

    def shape_after(shape, cfg, **forward_kwargs):
        # Instantiate from config, run a random batch, report the output shape.
        module = ConformerMHSAV1(cfg)
        out = module(torch.randn(shape), **forward_kwargs)
        return list(out.shape)

    # Without key padding mask: output shape must equal the input shape.
    assert shape_after([3, 10, 20], ConformerMHSAV1Config(20, 4, 0.1, 0.1)) == [3, 10, 20]

    # With a random boolean key padding mask of shape (B, T).
    mask = torch.randint(0, 2, (4, 15)) > 0
    assert shape_after([4, 15, 32], ConformerMHSAV1Config(32, 8, 0.2, 0.3), key_padding_mask=mask) == [4, 15, 32]
70+
71+
5172def test_layer_norm_nc ():
5273 torch .manual_seed (42 )
5374
0 commit comments