File tree Expand file tree Collapse file tree 3 files changed +5
-3
lines changed Expand file tree Collapse file tree 3 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -137,6 +137,7 @@ def test_classification_head(self):
137137
138138 self .assertEqual (cls_probs .shape [1 ], 10 )
139139
140+ @requires_torch_greater_or_equal ("2.0.0" )
140141 def test_save_load_with_hub_mixin (self ):
141142 # instantiate model
142143 model = smp .create_model (
@@ -172,7 +173,7 @@ def test_save_load_with_hub_mixin(self):
172173 self .assertIn ("my_awesome_metric" , readme )
173174
174175 @slow_test
175- @requires_torch_greater_or_equal ("2.0.1 " )
176+ @requires_torch_greater_or_equal ("2.0.0 " )
176177 def test_preserve_forward_output (self ):
177178 from huggingface_hub import hf_hub_download
178179
Original file line number Diff line number Diff line change 33import segmentation_models_pytorch as smp
44
55from tests .models import base
6- from tests .utils import slow_test , default_device
6+ from tests .utils import slow_test , default_device , requires_torch_greater_or_equal
77
88
99@pytest .mark .segformer
1010class TestSegformerModel (base .BaseModelTester ):
1111 test_model_type = "segformer"
1212
1313 @slow_test
14+ @requires_torch_greater_or_equal ("2.0.0" )
1415 def test_load_pretrained (self ):
1516 hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k"
1617
Original file line number Diff line number Diff line change @@ -42,6 +42,6 @@ def requires_torch_greater_or_equal(version: str):
4242 torch_version = Version (torch .__version__ )
4343 provided_version = Version (version )
4444 return unittest .skipUnless (
45- torch_version >= provided_version ,
45+ torch_version < provided_version ,
4646 f"torch version { torch_version } is less than { provided_version } " ,
4747 )
You can’t perform that action at this time.
0 commit comments