Skip to content

Commit aa2cf99

Browse files
committed
Add device
1 parent 4d51fac commit aa2cf99

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

tests/encoders/base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import segmentation_models_pytorch as smp
44

55
from functools import lru_cache
6-
6+
from tests.utils import default_device
77

88
class BaseEncoderTester(unittest.TestCase):
99
encoder_names = []
@@ -40,13 +40,13 @@ def test_forward_backward(self):
4040
num_channels=self.default_num_channels,
4141
height=self.default_height,
4242
width=self.default_width,
43-
)
43+
).to(default_device)
4444
for encoder_name in self.encoder_names:
4545
with self.subTest(encoder_name=encoder_name):
4646
# init encoder
4747
encoder = smp.encoders.get_encoder(
4848
encoder_name, in_channels=3, encoder_weights=None
49-
)
49+
).to(default_device)
5050

5151
# forward
5252
features = encoder.forward(sample)
@@ -72,12 +72,12 @@ def test_in_channels(self):
7272
num_channels=in_channels,
7373
height=self.default_height,
7474
width=self.default_width,
75-
)
75+
).to(default_device)
7676

7777
with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
7878
encoder = smp.encoders.get_encoder(
7979
encoder_name, in_channels=in_channels, encoder_weights=None
80-
)
80+
).to(default_device)
8181
encoder.eval()
8282

8383
# forward
@@ -90,7 +90,7 @@ def test_depth(self):
9090
num_channels=self.default_num_channels,
9191
height=self.default_height,
9292
width=self.default_width,
93-
)
93+
).to(default_device)
9494

9595
cases = [
9696
(encoder_name, depth)
@@ -105,7 +105,7 @@ def test_depth(self):
105105
in_channels=self.default_num_channels,
106106
encoder_weights=None,
107107
depth=depth,
108-
)
108+
).to(default_device)
109109
encoder.eval()
110110

111111
# forward
@@ -154,7 +154,7 @@ def test_dilated(self):
154154
num_channels=self.default_num_channels,
155155
height=self.default_height,
156156
width=self.default_width,
157-
)
157+
).to(default_device)
158158

159159
cases = [
160160
(encoder_name, stride)
@@ -172,7 +172,7 @@ def test_dilated(self):
172172
in_channels=self.default_num_channels,
173173
encoder_weights=None,
174174
output_stride=stride,
175-
)
175+
).to(default_device)
176176
return
177177

178178
for encoder_name, stride in cases:
@@ -182,7 +182,7 @@ def test_dilated(self):
182182
in_channels=self.default_num_channels,
183183
encoder_weights=None,
184184
output_stride=stride,
185-
)
185+
).to(default_device)
186186
encoder.eval()
187187

188188
# forward

0 commit comments

Comments
 (0)