Skip to content

Commit 58afe1b

Browse files
committed
Add device
1 parent a17bb6a commit 58afe1b

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ test: .venv
1010
.venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match"
1111

1212
test_all: .venv
13-
.venv/bin/pytest -v -rsx -n 2 tests/
13+
RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/
1414

1515
table:
1616
.venv/bin/python misc/generate_table.py

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ docs = [
4040
]
4141
test = [
4242
'pytest',
43+
'pytest-cov',
44+
'pytest-xdist',
4345
'ruff',
4446
]
4547

tests/models/base.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
import torch
1010
import segmentation_models_pytorch as smp
1111

12-
from tests.utils import has_timm_test_models, slow_test, requires_torch_greater_or_equal
12+
from tests.utils import (
13+
has_timm_test_models,
14+
default_device,
15+
slow_test,
16+
requires_torch_greater_or_equal,
17+
)
1318

1419

1520
class BaseModelTester(unittest.TestCase):
@@ -58,10 +63,10 @@ def test_forward_backward(self):
5863
num_channels=self.default_num_channels,
5964
height=self.default_height,
6065
width=self.default_width,
61-
)
66+
).to(default_device)
6267
model = smp.create_model(
6368
arch=self.model_type, encoder_name=self.test_encoder_name
64-
)
69+
).to(default_device)
6570

6671
# check default in_channels=3
6772
output = model(sample)
@@ -93,13 +98,13 @@ def test_in_channels_and_depth_and_out_classes(
9398
in_channels=in_channels,
9499
classes=classes,
95100
**kwargs,
96-
)
101+
).to(default_device)
97102
sample = self._get_sample(
98103
batch_size=self.default_batch_size,
99104
num_channels=in_channels,
100105
height=self.default_height,
101106
width=self.default_width,
102-
)
107+
).to(default_device)
103108

104109
# check in channels correctly set
105110
with torch.no_grad():
@@ -117,7 +122,7 @@ def test_classification_head(self):
117122
"dropout": 0.5,
118123
"activation": "sigmoid",
119124
},
120-
)
125+
).to(default_device)
121126

122127
self.assertIsNotNone(model.classification_head)
123128
self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d)
@@ -132,7 +137,7 @@ def test_classification_head(self):
132137
num_channels=self.default_num_channels,
133138
height=self.default_height,
134139
width=self.default_width,
135-
)
140+
).to(default_device)
136141

137142
with torch.no_grad():
138143
_, cls_probs = model(sample)
@@ -144,14 +149,14 @@ def test_save_load_with_hub_mixin(self):
144149
# instantiate model
145150
model = smp.create_model(
146151
arch=self.model_type, encoder_name=self.test_encoder_name
147-
)
152+
).to(default_device)
148153

149154
# save model
150155
with tempfile.TemporaryDirectory() as tmpdir:
151156
model.save_pretrained(
152157
tmpdir, dataset="test_dataset", metrics={"my_awesome_metric": 0.99}
153158
)
154-
restored_model = smp.from_pretrained(tmpdir)
159+
restored_model = smp.from_pretrained(tmpdir).to(default_device)
155160
with open(os.path.join(tmpdir, "README.md"), "r") as f:
156161
readme = f.read()
157162

@@ -161,7 +166,7 @@ def test_save_load_with_hub_mixin(self):
161166
num_channels=self.default_num_channels,
162167
height=self.default_height,
163168
width=self.default_width,
164-
)
169+
).to(default_device)
165170

166171
with torch.no_grad():
167172
output = model(sample)
@@ -178,7 +183,7 @@ def test_save_load_with_hub_mixin(self):
178183
@requires_torch_greater_or_equal("2.0.1")
179184
@pytest.mark.logits_match
180185
def test_preserve_forward_output(self):
181-
model = smp.from_pretrained(self.hub_checkpoint).eval()
186+
model = smp.from_pretrained(self.hub_checkpoint).eval().to(default_device)
182187

183188
input_tensor_path = hf_hub_download(
184189
repo_id=self.hub_checkpoint, filename="input-tensor.pth"
@@ -188,12 +193,14 @@ def test_preserve_forward_output(self):
188193
)
189194

190195
input_tensor = torch.load(input_tensor_path, weights_only=True)
196+
input_tensor = input_tensor.to(default_device)
191197
output_tensor = torch.load(output_tensor_path, weights_only=True)
198+
output_tensor = output_tensor.to(default_device)
192199

193200
with torch.no_grad():
194201
output = model(input_tensor)
195202

196203
self.assertEqual(output.shape, output_tensor.shape)
197-
is_close = torch.allclose(output, output_tensor, atol=1e-3)
204+
is_close = torch.allclose(output, output_tensor, atol=1e-2)
198205
max_diff = torch.max(torch.abs(output - output_tensor))
199206
self.assertTrue(is_close, f"Max diff: {max_diff}")

0 commit comments

Comments
 (0)