
Commit 8d06cba

Add test with hub checkpoint
1 parent 6346405 commit 8d06cba

9 files changed: 132 additions & 6 deletions

misc/generate_test_models.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+import os
+import torch
+import tempfile
+import huggingface_hub
+import segmentation_models_pytorch as smp
+
+HUB_REPO = "smp-test-models"
+ENCODER_NAME = "tu-resnet18"
+
+api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN"))
+
+for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
+    model = model_class(encoder_name=ENCODER_NAME)
+    model = model.eval()
+
+    # generate test sample
+    torch.manual_seed(423553)
+    sample = torch.rand(1, 3, 256, 256)
+
+    with torch.no_grad():
+        output = model(sample)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # save model
+        model.save_pretrained(f"{tmpdir}")
+
+        # save input and output
+        torch.save(sample, f"{tmpdir}/input-tensor.pth")
+        torch.save(output, f"{tmpdir}/output-tensor.pth")
+
+        # create repo
+        repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}"
+        if not api.repo_exists(repo_id=repo_id):
+            api.create_repo(repo_id=repo_id, repo_type="model")
+
+        # upload to hub
+        api.upload_folder(
+            folder_path=tmpdir,
+            repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}",
+            repo_type="model",
+        )
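Each uploaded repo pairs the checkpoint with the exact input/output tensors it was generated from, so a consumer can check that loading the weights reproduces the recorded forward pass. A minimal round-trip sketch, assuming the repo name "smp-test-models/unet-tu-resnet18" produced by the script above exists and is public (the test added to tests/models/base.py below does the same thing per architecture):

import torch
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download

# assumed repo id, following the script's {HUB_REPO}/{model_name}-{ENCODER_NAME} scheme
repo_id = "smp-test-models/unet-tu-resnet18"

# restore the model and the reference tensors saved next to the weights
model = smp.from_pretrained(repo_id).eval()
input_tensor = torch.load(hf_hub_download(repo_id, "input-tensor.pth"), weights_only=True)
expected = torch.load(hf_hub_download(repo_id, "output-tensor.pth"), weights_only=True)

with torch.no_grad():
    output = model(input_tensor)

assert torch.allclose(output, expected, atol=1e-3)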

tests/encoders/test_pretrainedmodels_encoders.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from tests.encoders import base
-from tests.config import RUN_ALL_ENCODERS
+from tests.utils import RUN_ALL_ENCODERS
 
 
 class TestDenseNetEncoder(base.BaseEncoderTester):

tests/encoders/test_smp_encoders.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from tests.encoders import base
-from tests.config import RUN_ALL_ENCODERS
+from tests.utils import RUN_ALL_ENCODERS
 
 
 class TestMobileoneEncoder(base.BaseEncoderTester):

tests/encoders/test_timm_ported_encoders.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from tests.encoders import base
-from tests.config import RUN_ALL_ENCODERS
+from tests.utils import RUN_ALL_ENCODERS
 
 
 class TestTimmEfficientNetEncoder(base.BaseEncoderTester):

tests/encoders/test_timm_universal.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from tests.encoders import base
-from tests.config import has_timm_test_models
+from tests.utils import has_timm_test_models
 
 # check if timm >= 1.0.12
 timm_encoders = [

tests/encoders/test_torchvision_encoders.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from tests.encoders import base
-from tests.config import RUN_ALL_ENCODERS
+from tests.utils import RUN_ALL_ENCODERS
 
 
 class TestMobileoneEncoder(base.BaseEncoderTester):

tests/models/base.py

Lines changed: 29 additions & 1 deletion
@@ -7,7 +7,7 @@
 import torch
 import segmentation_models_pytorch as smp
 
-from tests.config import has_timm_test_models
+from tests.utils import has_timm_test_models, slow_test
 
 
 class BaseModelTester(unittest.TestCase):
@@ -30,6 +30,10 @@ def model_type(self):
             raise ValueError("test_model_type is not set")
         return self.test_model_type
 
+    @property
+    def hub_checkpoint(self):
+        return f"smp-test-models/{self.model_type}-tu-resnet18"
+
     @property
     def model_class(self):
         return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type]
@@ -166,3 +170,27 @@ def test_save_load_with_hub_mixin(self):
         # check dataset and metrics are saved in readme
         self.assertIn("test_dataset", readme)
         self.assertIn("my_awesome_metric", readme)
+
+    @slow_test
+    def test_preserve_forward_output(self):
+        from huggingface_hub import hf_hub_download
+
+        model = smp.from_pretrained(self.hub_checkpoint).eval()
+
+        input_tensor_path = hf_hub_download(
+            repo_id=self.hub_checkpoint, filename="input-tensor.pth"
+        )
+        output_tensor_path = hf_hub_download(
+            repo_id=self.hub_checkpoint, filename="output-tensor.pth"
+        )
+
+        input_tensor = torch.load(input_tensor_path, weights_only=True)
+        output_tensor = torch.load(output_tensor_path, weights_only=True)
+
+        with torch.no_grad():
+            output = model(input_tensor)
+
+        self.assertEqual(output.shape, output_tensor.shape)
+        is_close = torch.allclose(output, output_tensor, atol=1e-3)
+        max_diff = torch.max(torch.abs(output - output_tensor))
+        self.assertTrue(is_close, f"Max diff: {max_diff}")
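With this change, a concrete model tester only needs to set test_model_type; the new hub_checkpoint property then resolves to the matching smp-test-models repo, and the inherited test_preserve_forward_output runs against it when slow tests are enabled. A hypothetical subclass sketch (the Unet tester name, the "unet" key, and the pytest marker are assumed by analogy with the segformer tester below):

import pytest
from tests.models import base


@pytest.mark.unet  # hypothetical marker, mirroring @pytest.mark.segformer
class TestUnetModel(base.BaseModelTester):
    test_model_type = "unet"
    # hub_checkpoint resolves to "smp-test-models/unet-tu-resnet18"
    # test_preserve_forward_output is inherited and gated by @slow_test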

tests/models/test_segformer.py

Lines changed: 34 additions & 0 deletions
@@ -1,7 +1,41 @@
+import torch
 import pytest
+import segmentation_models_pytorch as smp
+
 from tests.models import base
+from tests.utils import slow_test, default_device
 
 
 @pytest.mark.segformer
 class TestSegformerModel(base.BaseModelTester):
     test_model_type = "segformer"
+
+    @slow_test
+    def test_load_pretrained(self):
+        hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k"
+
+        model = smp.from_pretrained(hub_checkpoint)
+        model = model.eval().to(default_device)
+
+        sample = torch.ones([1, 3, 512, 512]).to(default_device)
+
+        with torch.no_grad():
+            output = model(sample)
+
+        self.assertEqual(output.shape, (1, 150, 512, 512))
+
+        expected_logits_slice = torch.tensor(
+            [-4.4172, -4.4723, -4.5273, -4.5824, -4.6375, -4.7157]
+        )
+        resulted_logits_slice = output[0, 0, 256, :6].cpu()
+        is_equal = torch.allclose(
+            expected_logits_slice, resulted_logits_slice, atol=1e-2
+        )
+        max_diff = torch.max(torch.abs(expected_logits_slice - resulted_logits_slice))
+        self.assertTrue(
+            is_equal,
+            f"Expected logits slice and resulted logits slice are not equal.\n"
+            f"Max diff: {max_diff}\n"
+            f"Expected: {expected_logits_slice}\n"
+            f"Resulted: {resulted_logits_slice}\n",
+        )

tests/config.py renamed to tests/utils.py

Lines changed: 23 additions & 0 deletions
@@ -1,9 +1,13 @@
 import os
 import timm
+import torch
+import unittest
+
 from packaging.version import Version
 
 
 has_timm_test_models = Version(timm.__version__) >= Version("1.0.12")
+default_device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 def get_commit_message():
@@ -28,3 +32,22 @@ def get_commit_message():
     os.getenv("RUN_SLOW", "false").lower() in ["true", "1", "y", "yes"]
     or "run-slow" in commit_message
 )
+
+
+def slow_test(test_case):
+    """
+    Decorator marking a test as slow.
+
+    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+    """
+    return unittest.skipUnless(RUN_SLOW, "test is slow")(test_case)
+
+
+def requires_torch_greater_or_equal(version: str):
+    torch_version = Version(torch.__version__)
+    provided_version = Version(version)
+    return unittest.skipUnless(
+        torch_version >= provided_version,
+        f"torch version {torch_version} is less than {provided_version}",
+    )
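The two new helpers are unittest-style skip decorators: slow_test gates a test on the RUN_SLOW flag (truthy RUN_SLOW environment variable or a "run-slow" commit message), and requires_torch_greater_or_equal skips a test when the installed torch is older than the requested version. A small usage sketch; the test names and the "2.0.0" threshold are illustrative only:

import unittest
import torch

from tests.utils import slow_test, requires_torch_greater_or_equal, default_device


class ExampleTest(unittest.TestCase):
    @slow_test
    def test_only_when_run_slow_is_set(self):
        # executed only when RUN_SLOW is truthy or the commit message contains "run-slow"
        self.assertEqual(torch.rand(1).to(default_device).numel(), 1)

    @requires_torch_greater_or_equal("2.0.0")
    def test_needs_recent_torch(self):
        # skipped with an explanatory message on older torch versions
        self.assertTrue(hasattr(torch, "compile"))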
