Skip to content

Commit 1d5e1ea

Browse files
committed
Add base params test
1 parent 7c947f8 commit 1d5e1ea

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

tests/models/base.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import inspect
23
from functools import lru_cache
34

45
import torch
@@ -23,6 +24,18 @@ def model_type(self):
2324
raise ValueError("test_model_type is not set")
2425
return self.test_model_type
2526

27+
@property
28+
def model_class(self):
29+
return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type]
30+
31+
@property
32+
def decoder_channels(self):
33+
signature = inspect.signature(self.model_class)
34+
# check if decoder_channels is in the signature
35+
if "decoder_channels" in signature.parameters:
36+
return signature.parameters["decoder_channels"].default
37+
return None
38+
2639
@lru_cache
2740
def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
2841
return torch.rand(batch_size, num_channels, height, width)
@@ -50,3 +63,29 @@ def test_forward_backward(self):
5063

5164
# check backward pass
5265
output.mean().backward()
66+
67+
def test_base_params_are_set(self, in_channels=1, depth=3, classes=7):
68+
kwargs = {}
69+
70+
if self.model_type in ["unet", "unetplusplus", "manet"]:
71+
kwargs = {"decoder_channels": self.decoder_channels[:depth]}
72+
73+
model = smp.create_model(
74+
arch=self.model_type,
75+
encoder_depth=depth,
76+
in_channels=in_channels,
77+
classes=classes,
78+
**kwargs,
79+
)
80+
sample = self._get_sample(
81+
batch_size=self.default_batch_size,
82+
num_channels=in_channels,
83+
height=self.default_height,
84+
width=self.default_width,
85+
)
86+
87+
# check in channels correctly set
88+
with torch.no_grad():
89+
output = model(sample)
90+
91+
self.assertEqual(output.shape[1], classes)

0 commit comments

Comments
 (0)