Skip to content

Commit 50c40d1

Browse files
committed
Add triggering compile tests on diff
1 parent 8aed7ef commit 50c40d1

File tree

9 files changed

+56
-20
lines changed

9 files changed

+56
-20
lines changed

requirements/test.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
gitpython==3.1.44
12
packaging==24.2
23
pytest==8.3.4
34
pytest-xdist==3.6.1
45
pytest-cov==6.0.0
5-
ruff==0.9.1
6+
ruff==0.9.1

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ def pytest_addoption(parser):
33
"--non-marked-only", action="store_true", help="Run only non-marked tests"
44
)
55

6+
67
def pytest_collection_modifyitems(config, items):
78
if config.getoption("--non-marked-only"):
89
non_marked_items = []
910
for item in items:
1011
# Check if the test has no marks
1112
if not item.own_markers:
1213
non_marked_items.append(item)
13-
14+
1415
# Update the test collection to only include non-marked tests
1516
items[:] = non_marked_items

tests/encoders/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import segmentation_models_pytorch as smp
55

66
from functools import lru_cache
7-
from tests.utils import default_device
7+
from tests.utils import default_device, check_run_test_on_diff_or_main
88

99

1010
class BaseEncoderTester(unittest.TestCase):
1111
encoder_names = []
1212

13+
# some tests might be slow, running them only on diff
14+
files_for_diff = []
15+
1316
# standard encoder configuration
1417
num_output_features = 6
1518
output_strides = [1, 2, 4, 8, 16, 32]
@@ -213,6 +216,9 @@ def test_dilated(self):
213216

214217
@pytest.mark.compile
215218
def test_compile(self):
219+
if not check_run_test_on_diff_or_main(self.files_for_diff):
220+
self.skipTest("No diff and not on `main`.")
221+
216222
sample = self._get_sample(
217223
batch_size=self.default_batch_size,
218224
num_channels=self.default_num_channels,

tests/encoders/test_pretrainedmodels_encoders.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class TestDPNEncoder(base.BaseEncoderTester):
1010
if not RUN_ALL_ENCODERS
1111
else ["dpn68", "dpn68b", "dpn92", "dpn98", "dpn107", "dpn131"]
1212
)
13+
files_for_diff = ["encoders/dpn.py"]
1314

1415
def get_tiny_encoder(self):
1516
params = {
@@ -32,11 +33,13 @@ class TestInceptionResNetV2Encoder(base.BaseEncoderTester):
3233
encoder_names = (
3334
["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"]
3435
)
36+
files_for_diff = ["encoders/inceptionresnetv2.py"]
3537

3638

3739
class TestInceptionV4Encoder(base.BaseEncoderTester):
3840
supports_dilated = False
3941
encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"]
42+
files_for_diff = ["encoders/inceptionv4.py"]
4043

4144

4245
class TestSeNetEncoder(base.BaseEncoderTester):
@@ -52,6 +55,7 @@ class TestSeNetEncoder(base.BaseEncoderTester):
5255
# "senet154", # extra large model
5356
]
5457
)
58+
files_for_diff = ["encoders/senet.py"]
5559

5660
def get_tiny_encoder(self):
5761
params = {
@@ -73,3 +77,4 @@ def get_tiny_encoder(self):
7377
class TestXceptionEncoder(base.BaseEncoderTester):
7478
supports_dilated = False
7579
encoder_names = ["xception"] if not RUN_ALL_ENCODERS else ["xception"]
80+
files_for_diff = ["encoders/xception.py"]

tests/encoders/test_smp_encoders.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class TestMobileoneEncoder(base.BaseEncoderTester):
1717
"mobileone_s4",
1818
]
1919
)
20+
files_for_diff = ["encoders/mobileone.py"]
2021

2122

2223
class TestMixTransformerEncoder(base.BaseEncoderTester):
@@ -25,6 +26,7 @@ class TestMixTransformerEncoder(base.BaseEncoderTester):
2526
if not RUN_ALL_ENCODERS
2627
else ["mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"]
2728
)
29+
files_for_diff = ["encoders/mix_transformer.py"]
2830

2931
def get_tiny_encoder(self):
3032
params = {
@@ -59,6 +61,7 @@ class TestEfficientNetEncoder(base.BaseEncoderTester):
5961
# "efficientnet-b7", # extra large model
6062
]
6163
)
64+
files_for_diff = ["encoders/efficientnet.py"]
6265

6366
def test_compile(self):
6467
self.skipTest("compile fullgraph is not supported for efficientnet encoders")

tests/encoders/test_timm_ported_encoders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class TestTimmEfficientNetEncoder(base.BaseEncoderTester):
2424
"timm-tf_efficientnet_lite4",
2525
]
2626
)
27+
files_for_diff = ["encoders/timm_efficientnet.py"]
2728

2829

2930
class TestTimmGERNetEncoder(base.BaseEncoderTester):
@@ -144,3 +145,4 @@ class TestTimmSkNetEncoder(base.BaseEncoderTester):
144145
"timm-skresnext50_32x4d",
145146
]
146147
)
148+
files_for_diff = ["encoders/timm_sknet.py"]

tests/encoders/test_timm_universal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414

1515
class TestTimmUniversalEncoder(base.BaseEncoderTester):
1616
encoder_names = timm_encoders
17+
files_for_diff = ["encoders/timm_universal.py"]

tests/encoders/test_torchvision_encoders.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class TestResNetEncoder(base.BaseEncoderTester):
2222
"resnext101_32x48d",
2323
]
2424
)
25+
files_for_diff = ["encoders/resnet.py"]
2526

2627
def get_tiny_encoder(self):
2728
params = {
@@ -39,6 +40,7 @@ class TestDenseNetEncoder(base.BaseEncoderTester):
3940
if not RUN_ALL_ENCODERS
4041
else ["densenet121", "densenet169", "densenet161"]
4142
)
43+
files_for_diff = ["encoders/densenet.py"]
4244

4345
def get_tiny_encoder(self):
4446
params = {
@@ -52,6 +54,7 @@ def get_tiny_encoder(self):
5254

5355
class TestMobileNetEncoder(base.BaseEncoderTester):
5456
encoder_names = ["mobilenet_v2"] if not RUN_ALL_ENCODERS else ["mobilenet_v2"]
57+
files_for_diff = ["encoders/mobilenet.py"]
5558

5659

5760
class TestVggEncoder(base.BaseEncoderTester):
@@ -70,6 +73,7 @@ class TestVggEncoder(base.BaseEncoderTester):
7073
"vgg19_bn",
7174
]
7275
)
76+
files_for_diff = ["encoders/vgg.py"]
7377

7478
def get_tiny_encoder(self):
7579
params = {

tests/utils.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,21 @@
11
import os
2+
import re
23
import timm
34
import torch
45
import unittest
56

7+
from git import Repo
8+
from typing import List
69
from packaging.version import Version
710

811

912
has_timm_test_models = Version(timm.__version__) >= Version("1.0.12")
1013
default_device = "cuda" if torch.cuda.is_available() else "cpu"
1114

12-
13-
def get_commit_message():
14-
commit_msg = os.getenv("COMMIT_MESSAGE", "")
15-
return commit_msg.lower()
16-
17-
18-
# Check both environment variables and commit message
19-
commit_message = get_commit_message()
20-
RUN_ALL_ENCODERS = (
21-
os.getenv("RUN_ALL_ENCODERS", "false").lower() in ["true", "1", "y", "yes"]
22-
or "run-all-encoders" in commit_message
23-
)
24-
25-
RUN_SLOW = (
26-
os.getenv("RUN_SLOW", "false").lower() in ["true", "1", "y", "yes"]
27-
or "run-slow" in commit_message
28-
)
15+
YES_LIST = ["true", "1", "y", "yes"]
16+
RUN_ALL_ENCODERS = os.getenv("RUN_ALL_ENCODERS", "false").lower() in YES_LIST
17+
RUN_SLOW = os.getenv("RUN_SLOW", "false").lower() in YES_LIST
18+
RUN_ALL = os.getenv("RUN_ALL", "false").lower() in YES_LIST
2919

3020

3121
def slow_test(test_case):
@@ -45,3 +35,26 @@ def requires_torch_greater_or_equal(version: str):
4535
torch_version >= provided_version,
4636
f"torch version {torch_version} is less than {provided_version}",
4737
)
38+
39+
40+
def check_run_test_on_diff_or_main(filepath_patterns: List[str]):
41+
if RUN_ALL:
42+
return True
43+
44+
try:
45+
repo = Repo(".")
46+
current_branch = repo.active_branch.name
47+
diff_files = repo.git.diff("main", name_only=True).splitlines()
48+
49+
except Exception:
50+
return True
51+
52+
if current_branch == "main":
53+
return True
54+
55+
for pattern in filepath_patterns:
56+
for file_path in diff_files:
57+
if re.search(pattern, file_path):
58+
return True
59+
60+
return False

0 commit comments

Comments
 (0)