Skip to content

Commit ead24b4

Browse files
committed
Refactor test do not skip
1 parent f70d861 commit ead24b4

File tree

9 files changed

+23
-28
lines changed

9 files changed

+23
-28
lines changed

segmentation_models_pytorch/base/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
class SegmentationModel(torch.nn.Module, SMPHubMixin):
1212
"""Base class for all segmentation models."""
1313

14+
_is_torch_scriptable = True
15+
_is_torch_exportable = True
16+
_is_torch_compilable = True
17+
1418
# if model supports shape not divisible by 2 ^ n set to False
1519
requires_divisible_input_shape = True
1620

segmentation_models_pytorch/decoders/unetplusplus/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ class UnetPlusPlus(SegmentationModel):
5656
5757
"""
5858

59+
_is_torch_scriptable = False
60+
5961
@supports_config_loading
6062
def __init__(
6163
self,

segmentation_models_pytorch/decoders/upernet/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class UPerNet(SegmentationModel):
4848
4949
"""
5050

51+
_is_torch_scriptable = False
52+
5153
@supports_config_loading
5254
def __init__(
5355
self,

segmentation_models_pytorch/encoders/efficientnet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
class EfficientNetEncoder(EfficientNet, EncoderMixin):
3636
_is_torch_scriptable = False
3737

38+
# works with torch 2.4.0, but not with torch 2.5.1
39+
_is_torch_compilable = False
40+
3841
def __init__(
3942
self,
4043
stage_idxs: List[int],

tests/encoders/base.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ class BaseEncoderTester(unittest.TestCase):
3333
depth_to_test = [3, 4, 5]
3434
strides_to_test = [8, 16] # 32 is a default one
3535

36-
# enable/disable tests
37-
do_test_torch_compile = True
38-
do_test_torch_export = True
39-
4036
def get_tiny_encoder(self):
4137
return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None)
4238

@@ -208,28 +204,25 @@ def test_dilated(self):
208204

209205
@pytest.mark.compile
210206
def test_compile(self):
211-
if not self.do_test_torch_compile:
212-
self.skipTest(
213-
f"torch_compile test is disabled for {self.encoder_names[0]}."
214-
)
215-
216207
if not check_run_test_on_diff_or_main(self.files_for_diff):
217208
self.skipTest("No diff and not on `main`.")
218209

219210
sample = self._get_sample().to(default_device)
220211

221-
encoder = self.get_tiny_encoder().eval().to(default_device)
212+
encoder = self.get_tiny_encoder()
213+
encoder = encoder.eval().to(default_device)
214+
222215
compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True)
223216

224-
with torch.inference_mode():
217+
if encoder._is_torch_compilable:
225218
compiled_encoder(sample)
219+
else:
220+
with self.assertRaises(Exception):
221+
compiled_encoder(sample)
226222

227223
@pytest.mark.torch_export
228224
@requires_torch_greater_or_equal("2.4.0")
229225
def test_torch_export(self):
230-
if not self.do_test_torch_export:
231-
self.skipTest(f"torch_export test is disabled for {self.encoder_names[0]}.")
232-
233226
if not check_run_test_on_diff_or_main(self.files_for_diff):
234227
self.skipTest("No diff and not on `main`.")
235228

tests/encoders/test_pretrainedmodels_encoders.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ class TestDPNEncoder(base.BaseEncoderTester):
1212
)
1313
files_for_diff = ["encoders/dpn.py"]
1414

15-
# works with torch 2.4.0, but not with torch 2.5.1
16-
# dynamo error, probably on Sequential + OrderedDict
17-
do_test_torch_export = False
18-
1915
def get_tiny_encoder(self):
2016
params = {
2117
"stage_idxs": [2, 3, 4, 6],
@@ -45,10 +41,6 @@ class TestInceptionV4Encoder(base.BaseEncoderTester):
4541
files_for_diff = ["encoders/inceptionv4.py"]
4642
supports_dilated = False
4743

48-
# works with torch 2.4.0, but not with torch 2.5.1
49-
# dynamo error, probably on Sequential + OrderedDict
50-
do_test_torch_export = False
51-
5244

5345
class TestSeNetEncoder(base.BaseEncoderTester):
5446
encoder_names = (

tests/encoders/test_smp_encoders.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,3 @@ class TestEfficientNetEncoder(base.BaseEncoderTester):
6262
]
6363
)
6464
files_for_diff = ["encoders/efficientnet.py"]
65-
66-
# torch_compile is not supported for efficientnet encoders
67-
do_test_torch_compile = False

tests/encoders/test_timm_ported_encoders.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ class TestTimmEfficientNetEncoder(base.BaseEncoderTester):
2626
)
2727
files_for_diff = ["encoders/timm_efficientnet.py"]
2828

29-
# works with torch 2.4.0, but not with torch 2.5.1
30-
do_test_torch_export = False
31-
3229

3330
class TestTimmGERNetEncoder(base.BaseEncoderTester):
3431
encoder_names = (

tests/models/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,11 @@ def test_torch_script(self):
263263
model = self.get_default_model()
264264
model.eval()
265265

266+
if not model._is_torch_scriptable:
267+
with self.assertRaises(RuntimeError):
268+
scripted_model = torch.jit.script(model)
269+
return
270+
266271
scripted_model = torch.jit.script(model)
267272

268273
with torch.inference_mode():

0 commit comments

Comments
 (0)