3
3
import segmentation_models_pytorch as smp
4
4
5
5
from functools import lru_cache
6
-
6
+ from tests . utils import default_device
7
7
8
8
class BaseEncoderTester (unittest .TestCase ):
9
9
encoder_names = []
@@ -40,13 +40,13 @@ def test_forward_backward(self):
40
40
num_channels = self .default_num_channels ,
41
41
height = self .default_height ,
42
42
width = self .default_width ,
43
- )
43
+ ). to ( default_device )
44
44
for encoder_name in self .encoder_names :
45
45
with self .subTest (encoder_name = encoder_name ):
46
46
# init encoder
47
47
encoder = smp .encoders .get_encoder (
48
48
encoder_name , in_channels = 3 , encoder_weights = None
49
- )
49
+ ). to ( default_device )
50
50
51
51
# forward
52
52
features = encoder .forward (sample )
@@ -72,12 +72,12 @@ def test_in_channels(self):
72
72
num_channels = in_channels ,
73
73
height = self .default_height ,
74
74
width = self .default_width ,
75
- )
75
+ ). to ( default_device )
76
76
77
77
with self .subTest (encoder_name = encoder_name , in_channels = in_channels ):
78
78
encoder = smp .encoders .get_encoder (
79
79
encoder_name , in_channels = in_channels , encoder_weights = None
80
- )
80
+ ). to ( default_device )
81
81
encoder .eval ()
82
82
83
83
# forward
@@ -90,7 +90,7 @@ def test_depth(self):
90
90
num_channels = self .default_num_channels ,
91
91
height = self .default_height ,
92
92
width = self .default_width ,
93
- )
93
+ ). to ( default_device )
94
94
95
95
cases = [
96
96
(encoder_name , depth )
@@ -105,7 +105,7 @@ def test_depth(self):
105
105
in_channels = self .default_num_channels ,
106
106
encoder_weights = None ,
107
107
depth = depth ,
108
- )
108
+ ). to ( default_device )
109
109
encoder .eval ()
110
110
111
111
# forward
@@ -154,7 +154,7 @@ def test_dilated(self):
154
154
num_channels = self .default_num_channels ,
155
155
height = self .default_height ,
156
156
width = self .default_width ,
157
- )
157
+ ). to ( default_device )
158
158
159
159
cases = [
160
160
(encoder_name , stride )
@@ -172,7 +172,7 @@ def test_dilated(self):
172
172
in_channels = self .default_num_channels ,
173
173
encoder_weights = None ,
174
174
output_stride = stride ,
175
- )
175
+ ). to ( default_device )
176
176
return
177
177
178
178
for encoder_name , stride in cases :
@@ -182,7 +182,7 @@ def test_dilated(self):
182
182
in_channels = self .default_num_channels ,
183
183
encoder_weights = None ,
184
184
output_stride = stride ,
185
- )
185
+ ). to ( default_device )
186
186
encoder .eval ()
187
187
188
188
# forward
0 commit comments