Skip to content

Commit 8df56dd

Browse files
committed
Update test_models.py
1 parent ea54d93 commit 8df56dd

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

tests/test_models.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,21 @@ def get_encoders():
2323

2424
def get_sample(model_class):
2525
if model_class in [
26-
smp.Unet,
27-
smp.Linknet,
2826
smp.FPN,
29-
smp.PSPNet,
27+
smp.Linknet,
28+
smp.Unet,
3029
smp.UnetPlusPlus,
3130
smp.MAnet,
3231
]:
3332
sample = torch.ones([1, 3, 64, 64])
34-
elif model_class == smp.PAN or model_class == smp.UPerNet:
33+
elif model_class == smp.PAN:
3534
sample = torch.ones([2, 3, 256, 256])
36-
elif model_class == smp.DeepLabV3 or model_class == smp.DeepLabV3Plus:
35+
elif model_class in [smp.DeepLabV3, smp.DeepLabV3Plus]:
3736
sample = torch.ones([2, 3, 128, 128])
37+
elif model_class in [smp.PSPNet, smp.UPerNet]:
38+
# Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input
39+
# from PSPModule pooling in PSPNet/UPerNet.
40+
sample = torch.ones([2, 3, 64, 64])
3841
else:
3942
raise ValueError("Not supported model class {}".format(model_class))
4043
return sample

0 commit comments

Comments
 (0)