Skip to content

Commit b7ce422

Browse files
committed
Add save-load test, add aux head test
1 parent 2b113f0 commit b7ce422

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

tests/models/base.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import unittest
21
import inspect
2+
import tempfile
3+
import unittest
34
from functools import lru_cache
45

56
import torch
@@ -89,3 +90,58 @@ def test_base_params_are_set(self, in_channels=1, depth=3, classes=7):
8990
output = model(sample)
9091

9192
self.assertEqual(output.shape[1], classes)
93+
94+
def test_aux_params(self):
95+
model = smp.create_model(
96+
arch=self.model_type,
97+
aux_params={
98+
"pooling": "avg",
99+
"classes": 10,
100+
"dropout": 0.5,
101+
"activation": "sigmoid",
102+
},
103+
)
104+
105+
self.assertIsNotNone(model.classification_head)
106+
self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d)
107+
self.assertIsInstance(model.classification_head[1], torch.nn.Flatten)
108+
self.assertIsInstance(model.classification_head[2], torch.nn.Dropout)
109+
self.assertEqual(model.classification_head[2].p, 0.5)
110+
self.assertIsInstance(model.classification_head[3], torch.nn.Linear)
111+
self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid)
112+
113+
sample = self._get_sample(
114+
batch_size=self.default_batch_size,
115+
num_channels=self.default_num_channels,
116+
height=self.default_height,
117+
width=self.default_width,
118+
)
119+
120+
with torch.no_grad():
121+
_, cls_probs = model(sample)
122+
123+
self.assertEqual(cls_probs.shape[1], 10)
124+
125+
def test_save_load(self):
126+
# instantiate model
127+
model = smp.create_model(arch=self.model_type)
128+
129+
# save model
130+
with tempfile.TemporaryDirectory() as tmpdir:
131+
model.save_pretrained(tmpdir)
132+
restored_model = model.from_pretrained(tmpdir)
133+
134+
# check inference is correct
135+
sample = self._get_sample(
136+
batch_size=self.default_batch_size,
137+
num_channels=self.default_num_channels,
138+
height=self.default_height,
139+
width=self.default_width,
140+
)
141+
142+
with torch.no_grad():
143+
output = model(sample)
144+
restored_output = restored_model(sample)
145+
146+
self.assertEqual(output.shape, restored_output.shape)
147+
self.assertEqual(output.shape[1], 1)

0 commit comments

Comments
 (0)