33import segmentation_models_pytorch as smp
44
55from functools import lru_cache
6-
6+ from tests . utils import default_device
77
88class BaseEncoderTester (unittest .TestCase ):
99 encoder_names = []
@@ -40,13 +40,13 @@ def test_forward_backward(self):
4040 num_channels = self .default_num_channels ,
4141 height = self .default_height ,
4242 width = self .default_width ,
43- )
43+ ). to ( default_device )
4444 for encoder_name in self .encoder_names :
4545 with self .subTest (encoder_name = encoder_name ):
4646 # init encoder
4747 encoder = smp .encoders .get_encoder (
4848 encoder_name , in_channels = 3 , encoder_weights = None
49- )
49+ ). to ( default_device )
5050
5151 # forward
5252 features = encoder .forward (sample )
@@ -72,12 +72,12 @@ def test_in_channels(self):
7272 num_channels = in_channels ,
7373 height = self .default_height ,
7474 width = self .default_width ,
75- )
75+ ). to ( default_device )
7676
7777 with self .subTest (encoder_name = encoder_name , in_channels = in_channels ):
7878 encoder = smp .encoders .get_encoder (
7979 encoder_name , in_channels = in_channels , encoder_weights = None
80- )
80+ ). to ( default_device )
8181 encoder .eval ()
8282
8383 # forward
@@ -90,7 +90,7 @@ def test_depth(self):
9090 num_channels = self .default_num_channels ,
9191 height = self .default_height ,
9292 width = self .default_width ,
93- )
93+ ). to ( default_device )
9494
9595 cases = [
9696 (encoder_name , depth )
@@ -105,7 +105,7 @@ def test_depth(self):
105105 in_channels = self .default_num_channels ,
106106 encoder_weights = None ,
107107 depth = depth ,
108- )
108+ ). to ( default_device )
109109 encoder .eval ()
110110
111111 # forward
@@ -154,7 +154,7 @@ def test_dilated(self):
154154 num_channels = self .default_num_channels ,
155155 height = self .default_height ,
156156 width = self .default_width ,
157- )
157+ ). to ( default_device )
158158
159159 cases = [
160160 (encoder_name , stride )
@@ -172,7 +172,7 @@ def test_dilated(self):
172172 in_channels = self .default_num_channels ,
173173 encoder_weights = None ,
174174 output_stride = stride ,
175- )
175+ ). to ( default_device )
176176 return
177177
178178 for encoder_name , stride in cases :
@@ -182,7 +182,7 @@ def test_dilated(self):
182182 in_channels = self .default_num_channels ,
183183 encoder_weights = None ,
184184 output_stride = stride ,
185- )
185+ ). to ( default_device )
186186 encoder .eval ()
187187
188188 # forward
0 commit comments