2
2
3
3
import torch
4
4
5
- from segmentation_models_pytorch import create_model
5
+ import segmentation_models_pytorch as smp
6
6
from tests .utils import check_two_models_strictly_equal
7
7
8
8
11
11
def test_seg_models_before_after_use_norm (model_name , decoder_option ):
12
12
torch .manual_seed (42 )
13
13
with pytest .warns (DeprecationWarning ):
14
- model_decoder_batchnorm = create_model (
15
- model_name , "mobilenet_v2" , None , decoder_use_batchnorm = decoder_option
14
+ model_decoder_batchnorm = smp .create_model (
15
+ model_name ,
16
+ "mobilenet_v2" ,
17
+ encoder_weights = None ,
18
+ decoder_use_batchnorm = decoder_option ,
16
19
)
17
- torch .manual_seed (42 )
18
- model_decoder_norm = create_model (
20
+ model_decoder_norm = smp .create_model (
19
21
model_name ,
20
22
"mobilenet_v2" ,
21
- None ,
22
- decoder_use_batchnorm = None ,
23
+ encoder_weights = None ,
23
24
decoder_use_norm = decoder_option ,
24
25
)
25
26
27
+ model_decoder_norm .load_state_dict (model_decoder_batchnorm .state_dict ())
28
+
26
29
check_two_models_strictly_equal (
27
30
model_decoder_batchnorm , model_decoder_norm , torch .rand (1 , 3 , 224 , 224 )
28
31
)
@@ -32,17 +35,19 @@ def test_seg_models_before_after_use_norm(model_name, decoder_option):
32
35
def test_pspnet_before_after_use_norm (decoder_option ):
33
36
torch .manual_seed (42 )
34
37
with pytest .warns (DeprecationWarning ):
35
- model_decoder_batchnorm = create_model (
36
- "pspnet" , "mobilenet_v2" , None , psp_use_batchnorm = decoder_option
38
+ model_decoder_batchnorm = smp .create_model (
39
+ "pspnet" ,
40
+ "mobilenet_v2" ,
41
+ encoder_weights = None ,
42
+ psp_use_batchnorm = decoder_option ,
37
43
)
38
- torch .manual_seed (42 )
39
- model_decoder_norm = create_model (
44
+ model_decoder_norm = smp .create_model (
40
45
"pspnet" ,
41
46
"mobilenet_v2" ,
42
- None ,
43
- psp_use_batchnorm = None ,
47
+ encoder_weights = None ,
44
48
decoder_use_norm = decoder_option ,
45
49
)
50
+ model_decoder_norm .load_state_dict (model_decoder_batchnorm .state_dict ())
46
51
47
52
check_two_models_strictly_equal (
48
53
model_decoder_batchnorm , model_decoder_norm , torch .rand (1 , 3 , 224 , 224 )
0 commit comments