Skip to content

Commit ea54d93

Browse files
committed
Update test_models.py
1 parent 0e2da9e commit ea54d93

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

tests/test_models.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@ def get_sample(model_class):
2929
smp.PSPNet,
3030
smp.UnetPlusPlus,
3131
smp.MAnet,
32-
smp.UPerNet,
3332
]:
3433
sample = torch.ones([1, 3, 64, 64])
35-
elif model_class == smp.PAN:
34+
elif model_class == smp.PAN or model_class == smp.UPerNet:
3635
sample = torch.ones([2, 3, 256, 256])
37-
elif model_class == smp.DeepLabV3:
36+
elif model_class == smp.DeepLabV3 or model_class == smp.DeepLabV3Plus:
3837
sample = torch.ones([2, 3, 128, 128])
3938
else:
4039
raise ValueError("Not supported model class {}".format(model_class))
@@ -102,6 +101,8 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
102101
smp.UnetPlusPlus,
103102
smp.MAnet,
104103
smp.DeepLabV3,
104+
smp.DeepLabV3Plus,
105+
smp.UPerNet,
105106
],
106107
)
107108
def test_forward_backward(model_class):
@@ -112,7 +113,18 @@ def test_forward_backward(model_class):
112113

113114
@pytest.mark.parametrize(
114115
"model_class",
115-
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet],
116+
[
117+
smp.PAN,
118+
smp.FPN,
119+
smp.PSPNet,
120+
smp.Linknet,
121+
smp.Unet,
122+
smp.UnetPlusPlus,
123+
smp.MAnet,
124+
smp.DeepLabV3,
125+
smp.DeepLabV3Plus,
126+
smp.UPerNet,
127+
],
116128
)
117129
def test_aux_output(model_class):
118130
model = model_class(

0 commit comments

Comments
 (0)