Skip to content

Commit 61a5496

Browse files
committed
Basic model test
1 parent 0474e77 commit 61a5496

14 files changed

+159
-0
lines changed

pyproject.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,18 @@ version = {attr = 'segmentation_models_pytorch.__version__.__version__'}
5555

5656
[tool.setuptools.packages.find]
5757
include = ['segmentation_models_pytorch*']
58+
59+
[tool.pytest.ini_options]
60+
markers = [
61+
"deeplabv3",
62+
"deeplabv3plus",
63+
"fpn",
64+
"linknet",
65+
"manet",
66+
"pan",
67+
"psp",
68+
"segformer",
69+
"unet",
70+
"unetplusplus",
71+
"upernet",
72+
]

tests/encoders/test_timm_universal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
class TestTimmUniversalEncoder(base.BaseEncoderTester):
55
encoder_names = [
6+
"tu-test_resnet.r160_in1k",
67
"tu-resnet18", # for timm universal traditional-like encoder
78
"tu-convnext_atto", # for timm universal transformer-like encoder
89
"tu-darknet17", # for timm universal vgg-like encoder

tests/models/__init__.py

Whitespace-only changes.

tests/models/base.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
from functools import lru_cache
3+
4+
import torch
5+
import segmentation_models_pytorch as smp
6+
7+
8+
class BaseModelTester(unittest.TestCase):
9+
test_encoder_name = "tu-test_resnet.r160_in1k"
10+
11+
# should be overriden
12+
test_model_type = None
13+
14+
# test sample configuration
15+
default_batch_size = 1
16+
default_num_channels = 3
17+
default_height = 64
18+
default_width = 64
19+
20+
@property
21+
def model_type(self):
22+
if self.test_model_type is None:
23+
raise ValueError("test_model_type is not set")
24+
return self.test_model_type
25+
26+
@lru_cache
27+
def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
28+
return torch.rand(batch_size, num_channels, height, width)
29+
30+
def test_forward_backward(self):
31+
sample = self._get_sample(
32+
batch_size=self.default_batch_size,
33+
num_channels=self.default_num_channels,
34+
height=self.default_height,
35+
width=self.default_width,
36+
)
37+
model = smp.create_model(arch=self.model_type)
38+
39+
# check default in_channels=3
40+
output = model(sample)
41+
42+
# check default output number of classes = 1
43+
expected_number_of_classes = 1
44+
result_number_of_classes = output.shape[1]
45+
self.assertEqual(
46+
result_number_of_classes,
47+
expected_number_of_classes,
48+
f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}",
49+
)
50+
51+
# check backward pass
52+
output.mean().backward()
53+
54+
def test_encoder_params_are_set(self):
55+
model = smp.create_model(arch=self.model_type)
56+
self.assertEqual(model.encoder.name, self.test_encoder_name)

tests/models/test_deeplab.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
from tests.models import base
3+
4+
5+
@pytest.mark.deeplabv3
6+
class TestDeeplabV3Model(base.BaseModelTester):
7+
test_model_type = "deeplabv3"
8+
9+
default_batch_size = 2
10+
11+
12+
@pytest.mark.deeplabv3plus
13+
class TestDeeplabV3PlusModel(base.BaseModelTester):
14+
test_model_type = "deeplabv3plus"
15+
16+
default_batch_size = 2

tests/models/test_fpn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
from tests.models import base
3+
4+
5+
@pytest.mark.fpn
6+
class TestFpnModel(base.BaseModelTester):
7+
test_model_type = "fpn"

tests/models/test_linknet.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
from tests.models import base
3+
4+
5+
@pytest.mark.linknet
6+
class TestLinknetModel(base.BaseModelTester):
7+
test_model_type = "linknet"

tests/models/test_manet.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
from tests.models import base
3+
4+
5+
@pytest.mark.manet
6+
class TestManetModel(base.BaseModelTester):
7+
test_model_type = "manet"

tests/models/test_pan.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import pytest
2+
from tests.models import base
3+
4+
5+
@pytest.mark.pan
6+
class TestPanModel(base.BaseModelTester):
7+
test_model_type = "pan"
8+
test_encoder_name = "resnet-18"
9+
10+
default_batch_size = 2
11+
default_height = 128
12+
default_width = 128

tests/models/test_psp.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import pytest
2+
from tests.models import base
3+
4+
5+
@pytest.mark.psp
6+
class TestPspModel(base.BaseModelTester):
7+
test_model_type = "pspnet"
8+
9+
default_batch_size = 2

0 commit comments

Comments
 (0)