File tree Expand file tree Collapse file tree 3 files changed +5
-3
lines changed Expand file tree Collapse file tree 3 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -137,6 +137,7 @@ def test_classification_head(self):
137
137
138
138
self .assertEqual (cls_probs .shape [1 ], 10 )
139
139
140
+ @requires_torch_greater_or_equal ("2.0.0" )
140
141
def test_save_load_with_hub_mixin (self ):
141
142
# instantiate model
142
143
model = smp .create_model (
@@ -172,7 +173,7 @@ def test_save_load_with_hub_mixin(self):
172
173
self .assertIn ("my_awesome_metric" , readme )
173
174
174
175
@slow_test
175
- @requires_torch_greater_or_equal ("2.0.1 " )
176
+ @requires_torch_greater_or_equal ("2.0.0 " )
176
177
def test_preserve_forward_output (self ):
177
178
from huggingface_hub import hf_hub_download
178
179
Original file line number Diff line number Diff line change 3
3
import segmentation_models_pytorch as smp
4
4
5
5
from tests .models import base
6
- from tests .utils import slow_test , default_device
6
+ from tests .utils import slow_test , default_device , requires_torch_greater_or_equal
7
7
8
8
9
9
@pytest .mark .segformer
10
10
class TestSegformerModel (base .BaseModelTester ):
11
11
test_model_type = "segformer"
12
12
13
13
@slow_test
14
+ @requires_torch_greater_or_equal ("2.0.0" )
14
15
def test_load_pretrained (self ):
15
16
hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k"
16
17
Original file line number Diff line number Diff line change @@ -42,6 +42,6 @@ def requires_torch_greater_or_equal(version: str):
42
42
torch_version = Version (torch .__version__ )
43
43
provided_version = Version (version )
44
44
return unittest .skipUnless (
45
- torch_version >= provided_version ,
45
+ torch_version < provided_version ,
46
46
f"torch version { torch_version } is less than { provided_version } " ,
47
47
)
You can’t perform that action at this time.
0 commit comments