11import unittest
2+ import inspect
23from functools import lru_cache
34
45import torch
@@ -23,6 +24,18 @@ def model_type(self):
2324 raise ValueError ("test_model_type is not set" )
2425 return self .test_model_type
2526
27+ @property
28+ def model_class (self ):
29+ return smp .MODEL_ARCHITECTURES_MAPPING [self .model_type ]
30+
31+ @property
32+ def decoder_channels (self ):
33+ signature = inspect .signature (self .model_class )
34+ # check if decoder_channels is in the signature
35+ if "decoder_channels" in signature .parameters :
36+ return signature .parameters ["decoder_channels" ].default
37+ return None
38+
2639 @lru_cache
2740 def _get_sample (self , batch_size = 1 , num_channels = 3 , height = 32 , width = 32 ):
2841 return torch .rand (batch_size , num_channels , height , width )
@@ -50,3 +63,29 @@ def test_forward_backward(self):
5063
5164 # check backward pass
5265 output .mean ().backward ()
66+
67+ def test_base_params_are_set (self , in_channels = 1 , depth = 3 , classes = 7 ):
68+ kwargs = {}
69+
70+ if self .model_type in ["unet" , "unetplusplus" , "manet" ]:
71+ kwargs = {"decoder_channels" : self .decoder_channels [:depth ]}
72+
73+ model = smp .create_model (
74+ arch = self .model_type ,
75+ encoder_depth = depth ,
76+ in_channels = in_channels ,
77+ classes = classes ,
78+ ** kwargs ,
79+ )
80+ sample = self ._get_sample (
81+ batch_size = self .default_batch_size ,
82+ num_channels = in_channels ,
83+ height = self .default_height ,
84+ width = self .default_width ,
85+ )
86+
87+ # check in channels correctly set
88+ with torch .no_grad ():
89+ output = model (sample )
90+
91+ self .assertEqual (output .shape [1 ], classes )
0 commit comments