@@ -1,5 +1,6 @@
-import unittest
 import inspect
+import tempfile
+import unittest
 from functools import lru_cache
 
 import torch
|
@@ -89,3 +90,58 @@ def test_base_params_are_set(self, in_channels=1, depth=3, classes=7):
         output = model(sample)
 
         self.assertEqual(output.shape[1], classes)
+
+    def test_aux_params(self):
+        model = smp.create_model(
+            arch=self.model_type,
+            aux_params={
+                "pooling": "avg",
+                "classes": 10,
+                "dropout": 0.5,
+                "activation": "sigmoid",
+            },
+        )
+
+        self.assertIsNotNone(model.classification_head)
+        self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d)
+        self.assertIsInstance(model.classification_head[1], torch.nn.Flatten)
+        self.assertIsInstance(model.classification_head[2], torch.nn.Dropout)
+        self.assertEqual(model.classification_head[2].p, 0.5)
+        self.assertIsInstance(model.classification_head[3], torch.nn.Linear)
+        self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid)
+
+        sample = self._get_sample(
+            batch_size=self.default_batch_size,
+            num_channels=self.default_num_channels,
+            height=self.default_height,
+            width=self.default_width,
+        )
+
+        with torch.no_grad():
+            _, cls_probs = model(sample)
+
+        self.assertEqual(cls_probs.shape[1], 10)
+
+    def test_save_load(self):
+        # instantiate model
+        model = smp.create_model(arch=self.model_type)
+
+        # save model
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_pretrained(tmpdir)
+            restored_model = model.from_pretrained(tmpdir)
+
+        # check inference is correct
+        sample = self._get_sample(
+            batch_size=self.default_batch_size,
+            num_channels=self.default_num_channels,
+            height=self.default_height,
+            width=self.default_width,
+        )
+
+        with torch.no_grad():
+            output = model(sample)
+            restored_output = restored_model(sample)
+
+        self.assertEqual(output.shape, restored_output.shape)
+        self.assertEqual(output.shape[1], 1)
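
For context, a minimal usage sketch of what the new tests exercise, using only the calls that appear in the diff (smp.create_model with aux_params, save_pretrained / from_pretrained). The "unet" arch and the 3-channel 64x64 input are illustrative assumptions, not values taken from the test suite.

    import tempfile

    import torch
    import segmentation_models_pytorch as smp

    # Model with an auxiliary classification head
    # (avg-pool -> flatten -> dropout -> linear -> sigmoid),
    # matching what test_aux_params asserts.
    model = smp.create_model(
        arch="unet",  # assumed arch, for illustration only
        aux_params={
            "pooling": "avg",
            "classes": 10,
            "dropout": 0.5,
            "activation": "sigmoid",
        },
    )

    sample = torch.rand(1, 3, 64, 64)  # assumed input size
    with torch.no_grad():
        mask, cls_probs = model(sample)  # aux head adds a second output

    # Round-trip through save_pretrained / from_pretrained, as in test_save_load.
    with tempfile.TemporaryDirectory() as tmpdir:
        model.save_pretrained(tmpdir)
        restored_model = model.from_pretrained(tmpdir)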