Skip to content

Commit ff278c9

Browse files
committed
Add compile test for models
1 parent ae3cb8a commit ff278c9

File tree

11 files changed

+32
-0
lines changed

11 files changed

+32
-0
lines changed

tests/models/base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
default_device,
1515
slow_test,
1616
requires_torch_greater_or_equal,
17+
check_run_test_on_diff_or_main,
1718
)
1819

1920

2021
class BaseModelTester(unittest.TestCase):
2122
test_encoder_name = (
2223
"tu-test_resnet.r160_in1k" if has_timm_test_models else "resnet18"
2324
)
25+
files_for_diff = [r".*"]
2426

2527
# should be overriden
2628
test_model_type = None
@@ -234,3 +236,21 @@ def test_preserve_forward_output(self):
234236
is_close = torch.allclose(output, output_tensor, atol=5e-2)
235237
max_diff = torch.max(torch.abs(output - output_tensor))
236238
self.assertTrue(is_close, f"Max diff: {max_diff}")
239+
240+
@pytest.mark.compile
241+
def test_compile(self):
242+
if not check_run_test_on_diff_or_main(self.files_for_diff):
243+
self.skipTest("No diff and not on `main`.")
244+
245+
sample = self._get_sample(
246+
batch_size=self.default_batch_size,
247+
num_channels=self.default_num_channels,
248+
height=self.default_height,
249+
width=self.default_width,
250+
).to(default_device)
251+
252+
model = self.get_default_model()
253+
compiled_model = torch.compile(model, fullgraph=True, dynamic=True)
254+
255+
with torch.inference_mode():
256+
compiled_model(sample)

tests/models/test_deeplab.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
class TestDeeplabV3Model(base.BaseModelTester):
55
test_model_type = "deeplabv3"
6+
files_for_diff = [r"decoders/deeplabv3/", r"base/"]
67

78
default_batch_size = 2
89

910

1011
class TestDeeplabV3PlusModel(base.BaseModelTester):
1112
test_model_type = "deeplabv3plus"
13+
files_for_diff = [r"decoders/deeplabv3plus/", r"base/"]
1214

1315
default_batch_size = 2

tests/models/test_fpn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44
class TestFpnModel(base.BaseModelTester):
55
test_model_type = "fpn"
6+
files_for_diff = [r"decoders/fpn/", r"base/"]

tests/models/test_linknet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44
class TestLinknetModel(base.BaseModelTester):
55
test_model_type = "linknet"
6+
files_for_diff = [r"decoders/linknet/", r"base/"]

tests/models/test_manet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44
class TestManetModel(base.BaseModelTester):
55
test_model_type = "manet"
6+
files_for_diff = [r"decoders/manet/", r"base/"]

tests/models/test_pan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
class TestPanModel(base.BaseModelTester):
55
test_model_type = "pan"
6+
files_for_diff = [r"decoders/pan/", r"base/"]
67

78
default_batch_size = 2
89
default_height = 128

tests/models/test_psp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33

44
class TestPspModel(base.BaseModelTester):
55
test_model_type = "pspnet"
6+
files_for_diff = [r"decoders/pspnet/", r"base/"]
67

78
default_batch_size = 2

tests/models/test_segformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
class TestSegformerModel(base.BaseModelTester):
1010
test_model_type = "segformer"
11+
files_for_diff = [r"decoders/segformer/", r"base/"]
1112

1213
@slow_test
1314
@requires_torch_greater_or_equal("2.0.1")

tests/models/test_unet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44
class TestUnetModel(base.BaseModelTester):
55
test_model_type = "unet"
6+
files_for_diff = [r"decoders/unet/", r"base/"]

tests/models/test_unetplusplus.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44
class TestUnetPlusPlusModel(base.BaseModelTester):
55
test_model_type = "unetplusplus"
6+
files_for_diff = [r"decoders/unetplusplus/", r"base/"]

0 commit comments

Comments
 (0)