|
| 1 | +import unittest |
| 2 | +from functools import lru_cache |
| 3 | + |
| 4 | +import torch |
| 5 | +import segmentation_models_pytorch as smp |
| 6 | + |
| 7 | + |
| 8 | +class BaseModelTester(unittest.TestCase): |
| 9 | + test_encoder_name = "tu-test_resnet.r160_in1k" |
| 10 | + |
| 11 | + # should be overriden |
| 12 | + test_model_type = None |
| 13 | + |
| 14 | + # test sample configuration |
| 15 | + default_batch_size = 1 |
| 16 | + default_num_channels = 3 |
| 17 | + default_height = 64 |
| 18 | + default_width = 64 |
| 19 | + |
| 20 | + @property |
| 21 | + def model_type(self): |
| 22 | + if self.test_model_type is None: |
| 23 | + raise ValueError("test_model_type is not set") |
| 24 | + return self.test_model_type |
| 25 | + |
| 26 | + @lru_cache |
| 27 | + def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): |
| 28 | + return torch.rand(batch_size, num_channels, height, width) |
| 29 | + |
| 30 | + def test_forward_backward(self): |
| 31 | + sample = self._get_sample( |
| 32 | + batch_size=self.default_batch_size, |
| 33 | + num_channels=self.default_num_channels, |
| 34 | + height=self.default_height, |
| 35 | + width=self.default_width, |
| 36 | + ) |
| 37 | + model = smp.create_model(arch=self.model_type) |
| 38 | + |
| 39 | + # check default in_channels=3 |
| 40 | + output = model(sample) |
| 41 | + |
| 42 | + # check default output number of classes = 1 |
| 43 | + expected_number_of_classes = 1 |
| 44 | + result_number_of_classes = output.shape[1] |
| 45 | + self.assertEqual( |
| 46 | + result_number_of_classes, |
| 47 | + expected_number_of_classes, |
| 48 | + f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}", |
| 49 | + ) |
| 50 | + |
| 51 | + # check backward pass |
| 52 | + output.mean().backward() |
| 53 | + |
| 54 | + def test_encoder_params_are_set(self): |
| 55 | + model = smp.create_model(arch=self.model_type) |
| 56 | + self.assertEqual(model.encoder.name, self.test_encoder_name) |
0 commit comments