99import torch
1010import segmentation_models_pytorch as smp
1111
12- from tests .utils import has_timm_test_models , slow_test , requires_torch_greater_or_equal
12+ from tests .utils import (
13+ has_timm_test_models ,
14+ default_device ,
15+ slow_test ,
16+ requires_torch_greater_or_equal ,
17+ )
1318
1419
1520class BaseModelTester (unittest .TestCase ):
@@ -58,10 +63,10 @@ def test_forward_backward(self):
5863 num_channels = self .default_num_channels ,
5964 height = self .default_height ,
6065 width = self .default_width ,
61- )
66+ ). to ( default_device )
6267 model = smp .create_model (
6368 arch = self .model_type , encoder_name = self .test_encoder_name
64- )
69+ ). to ( default_device )
6570
6671 # check default in_channels=3
6772 output = model (sample )
@@ -93,13 +98,13 @@ def test_in_channels_and_depth_and_out_classes(
9398 in_channels = in_channels ,
9499 classes = classes ,
95100 ** kwargs ,
96- )
101+ ). to ( default_device )
97102 sample = self ._get_sample (
98103 batch_size = self .default_batch_size ,
99104 num_channels = in_channels ,
100105 height = self .default_height ,
101106 width = self .default_width ,
102- )
107+ ). to ( default_device )
103108
104109 # check in channels correctly set
105110 with torch .no_grad ():
@@ -117,7 +122,7 @@ def test_classification_head(self):
117122 "dropout" : 0.5 ,
118123 "activation" : "sigmoid" ,
119124 },
120- )
125+ ). to ( default_device )
121126
122127 self .assertIsNotNone (model .classification_head )
123128 self .assertIsInstance (model .classification_head [0 ], torch .nn .AdaptiveAvgPool2d )
@@ -132,7 +137,7 @@ def test_classification_head(self):
132137 num_channels = self .default_num_channels ,
133138 height = self .default_height ,
134139 width = self .default_width ,
135- )
140+ ). to ( default_device )
136141
137142 with torch .no_grad ():
138143 _ , cls_probs = model (sample )
@@ -144,14 +149,14 @@ def test_save_load_with_hub_mixin(self):
144149 # instantiate model
145150 model = smp .create_model (
146151 arch = self .model_type , encoder_name = self .test_encoder_name
147- )
152+ ). to ( default_device )
148153
149154 # save model
150155 with tempfile .TemporaryDirectory () as tmpdir :
151156 model .save_pretrained (
152157 tmpdir , dataset = "test_dataset" , metrics = {"my_awesome_metric" : 0.99 }
153158 )
154- restored_model = smp .from_pretrained (tmpdir )
159+ restored_model = smp .from_pretrained (tmpdir ). to ( default_device )
155160 with open (os .path .join (tmpdir , "README.md" ), "r" ) as f :
156161 readme = f .read ()
157162
@@ -161,7 +166,7 @@ def test_save_load_with_hub_mixin(self):
161166 num_channels = self .default_num_channels ,
162167 height = self .default_height ,
163168 width = self .default_width ,
164- )
169+ ). to ( default_device )
165170
166171 with torch .no_grad ():
167172 output = model (sample )
@@ -178,7 +183,7 @@ def test_save_load_with_hub_mixin(self):
178183 @requires_torch_greater_or_equal ("2.0.1" )
179184 @pytest .mark .logits_match
180185 def test_preserve_forward_output (self ):
181- model = smp .from_pretrained (self .hub_checkpoint ).eval ()
186+ model = smp .from_pretrained (self .hub_checkpoint ).eval (). to ( default_device )
182187
183188 input_tensor_path = hf_hub_download (
184189 repo_id = self .hub_checkpoint , filename = "input-tensor.pth"
@@ -188,12 +193,14 @@ def test_preserve_forward_output(self):
188193 )
189194
190195 input_tensor = torch .load (input_tensor_path , weights_only = True )
196+ input_tensor = input_tensor .to (default_device )
191197 output_tensor = torch .load (output_tensor_path , weights_only = True )
198+ output_tensor = output_tensor .to (default_device )
192199
193200 with torch .no_grad ():
194201 output = model (input_tensor )
195202
196203 self .assertEqual (output .shape , output_tensor .shape )
197- is_close = torch .allclose (output , output_tensor , atol = 1e-3 )
204+ is_close = torch .allclose (output , output_tensor , atol = 1e-2 )
198205 max_diff = torch .max (torch .abs (output - output_tensor ))
199206 self .assertTrue (is_close , f"Max diff: { max_diff } " )
0 commit comments