1
1
import unittest
2
+ import inspect
2
3
from functools import lru_cache
3
4
4
5
import torch
@@ -23,6 +24,18 @@ def model_type(self):
23
24
raise ValueError ("test_model_type is not set" )
24
25
return self .test_model_type
25
26
27
+ @property
28
+ def model_class (self ):
29
+ return smp .MODEL_ARCHITECTURES_MAPPING [self .model_type ]
30
+
31
+ @property
32
+ def decoder_channels (self ):
33
+ signature = inspect .signature (self .model_class )
34
+ # check if decoder_channels is in the signature
35
+ if "decoder_channels" in signature .parameters :
36
+ return signature .parameters ["decoder_channels" ].default
37
+ return None
38
+
26
39
    # NOTE(review): lru_cache on an instance method keys on `self` too, so it
    # keeps each test instance (and its cached tensors) alive for the cache's
    # lifetime. Here that is deliberate: repeated calls with the same
    # dimensions return the SAME tensor object rather than a fresh random one.
    @lru_cache
    def _get_sample(self, batch_size: int = 1, num_channels: int = 3, height: int = 32, width: int = 32) -> "torch.Tensor":
        # Uniform random input in NCHW layout: (batch, channels, height, width).
        return torch.rand(batch_size, num_channels, height, width)
@@ -50,3 +63,29 @@ def test_forward_backward(self):
50
63
51
64
# check backward pass
52
65
output .mean ().backward ()
66
+
67
+ def test_base_params_are_set (self , in_channels = 1 , depth = 3 , classes = 7 ):
68
+ kwargs = {}
69
+
70
+ if self .model_type in ["unet" , "unetplusplus" , "manet" ]:
71
+ kwargs = {"decoder_channels" : self .decoder_channels [:depth ]}
72
+
73
+ model = smp .create_model (
74
+ arch = self .model_type ,
75
+ encoder_depth = depth ,
76
+ in_channels = in_channels ,
77
+ classes = classes ,
78
+ ** kwargs ,
79
+ )
80
+ sample = self ._get_sample (
81
+ batch_size = self .default_batch_size ,
82
+ num_channels = in_channels ,
83
+ height = self .default_height ,
84
+ width = self .default_width ,
85
+ )
86
+
87
+ # check in channels correctly set
88
+ with torch .no_grad ():
89
+ output = model (sample )
90
+
91
+ self .assertEqual (output .shape [1 ], classes )
0 commit comments