@@ -29,7 +29,7 @@ def setUpClass(cls):
2929
3030 ## Base
3131 ni_1 = Input (x_1_input_shape , name = 'test_ni1' )
32- nn_1 = Conv1d (out_channels = 32 , kernel_size = 5 , stride = 2 , name = 'test_conv1d' )(ni_1 )
32+ nn_1 = Conv1d (out_channels = 32 , in_channels = 1 , stride = 2 , name = 'test_conv1d' )(ni_1 )
3333 n1_b = BatchNorm (name = 'test_conv' )(nn_1 )
3434 cls .n1_b = n1_b
3535
@@ -47,7 +47,7 @@ class bn_0d_model(tlx.nn.Module):
4747
4848 def __init__ (self ):
4949 super (bn_0d_model , self ).__init__ ()
50- self .fc = Dense (32 , in_channels = 10 )
50+ self .fc = Linear (32 , in_features = 10 )
5151 self .bn = BatchNorm (num_features = 32 , name = 'test_bn1d' )
5252
5353 def forward (self , x ):
@@ -61,7 +61,7 @@ def forward(self, x):
6161
6262 nin_0 = Input (x_0_input_shape , name = 'test_in1' )
6363
64- n0 = Dense (32 )(nin_0 )
64+ n0 = Linear (32 )(nin_0 )
6565 n0 = BatchNorm1d (name = 'test_bn0d' )(n0 )
6666
6767 cls .n0 = n0
@@ -70,7 +70,7 @@ class bn_0d_model(tlx.nn.Module):
7070
7171 def __init__ (self ):
7272 super (bn_0d_model , self ).__init__ (name = 'test_bn_0d_model' )
73- self .fc = Dense (32 , in_channels = 10 )
73+ self .fc = Linear (32 , in_features = 10 )
7474 self .bn = BatchNorm1d (num_features = 32 , name = 'test_bn1d' )
7575
7676 def forward (self , x ):
@@ -104,7 +104,6 @@ def forward(self, x):
104104 ## 2D ========================================================================
105105
106106 nin_2 = Input (x_2_input_shape , name = 'test_in2' )
107-
108107 n2 = Conv2d (out_channels = 32 , kernel_size = (3 , 3 ), stride = (2 , 2 ), name = 'test_conv2d' )(nin_2 )
109108 n2 = BatchNorm2d (name = 'test_bn2d' )(n2 )
110109
@@ -126,7 +125,6 @@ def forward(self, x):
126125 ## 3D ========================================================================
127126
128127 nin_3 = Input (x_3_input_shape , name = 'test_in3' )
129-
130128 n3 = Conv3d (out_channels = 32 , kernel_size = (3 , 3 , 3 ), stride = (2 , 2 , 2 ), name = 'test_conv3d' )(nin_3 )
131129 n3 = BatchNorm3d (name = 'test_bn3d' , act = tlx .ReLU )(n3 )
132130
0 commit comments