77import torch
88import segmentation_models_pytorch as smp
99
10- from tests .config import has_timm_test_models
10+ from tests .utils import has_timm_test_models , slow_test
1111
1212
1313class BaseModelTester (unittest .TestCase ):
@@ -30,6 +30,10 @@ def model_type(self):
3030 raise ValueError ("test_model_type is not set" )
3131 return self .test_model_type
3232
33+ @property
34+ def hub_checkpoint (self ):
35+ return f"smp-test-models/{ self .model_type } -tu-resnet18"
36+
3337 @property
3438 def model_class (self ):
3539 return smp .MODEL_ARCHITECTURES_MAPPING [self .model_type ]
@@ -166,3 +170,27 @@ def test_save_load_with_hub_mixin(self):
166170 # check dataset and metrics are saved in readme
167171 self .assertIn ("test_dataset" , readme )
168172 self .assertIn ("my_awesome_metric" , readme )
173+
174+ @slow_test
175+ def test_preserve_forward_output (self ):
176+ from huggingface_hub import hf_hub_download
177+
178+ model = smp .from_pretrained (self .hub_checkpoint ).eval ()
179+
180+ input_tensor_path = hf_hub_download (
181+ repo_id = self .hub_checkpoint , filename = "input-tensor.pth"
182+ )
183+ output_tensor_path = hf_hub_download (
184+ repo_id = self .hub_checkpoint , filename = "output-tensor.pth"
185+ )
186+
187+ input_tensor = torch .load (input_tensor_path , weights_only = True )
188+ output_tensor = torch .load (output_tensor_path , weights_only = True )
189+
190+ with torch .no_grad ():
191+ output = model (input_tensor )
192+
193+ self .assertEqual (output .shape , output_tensor .shape )
194+ is_close = torch .allclose (output , output_tensor , atol = 1e-3 )
195+ max_diff = torch .max (torch .abs (output - output_tensor ))
196+ self .assertTrue (is_close , f"Max diff: { max_diff } " )
0 commit comments