Skip to content

Commit ae0977e

Browse files
committed
Disable models tests
1 parent 988c85f commit ae0977e

File tree

1 file changed

+194
-194
lines changed

1 file changed

+194
-194
lines changed

tests/test_models.py

Lines changed: 194 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -1,194 +1,194 @@
1-
import pytest
2-
import torch
3-
4-
import segmentation_models_pytorch as smp # noqa
5-
6-
7-
def get_encoders():
8-
exclude_encoders = [
9-
"senet154",
10-
"resnext101_32x16d",
11-
"resnext101_32x32d",
12-
"resnext101_32x48d",
13-
]
14-
encoders = smp.encoders.get_encoder_names()
15-
encoders = [e for e in encoders if e not in exclude_encoders]
16-
encoders.append("tu-resnet34") # for timm universal traditional-like encoder
17-
encoders.append("tu-convnext_atto") # for timm universal transformer-like encoder
18-
encoders.append("tu-darknet17") # for timm universal vgg-like encoder
19-
return encoders
20-
21-
22-
ENCODERS = get_encoders()
23-
DEFAULT_ENCODER = "resnet18"
24-
25-
26-
def get_sample(model_class):
27-
if model_class in [
28-
smp.FPN,
29-
smp.Linknet,
30-
smp.Unet,
31-
smp.UnetPlusPlus,
32-
smp.MAnet,
33-
smp.Segformer,
34-
]:
35-
sample = torch.ones([1, 3, 64, 64])
36-
elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]:
37-
sample = torch.ones([2, 3, 128, 128])
38-
elif model_class in [smp.PSPNet, smp.UPerNet]:
39-
# Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input
40-
# from PSPModule pooling in PSPNet/UPerNet.
41-
sample = torch.ones([2, 3, 64, 64])
42-
else:
43-
raise ValueError("Not supported model class {}".format(model_class))
44-
return sample
45-
46-
47-
def _test_forward(model, sample, test_shape=False):
48-
with torch.no_grad():
49-
out = model(sample)
50-
if test_shape:
51-
assert out.shape[2:] == sample.shape[2:]
52-
53-
54-
def _test_forward_backward(model, sample, test_shape=False):
55-
out = model(sample)
56-
out.mean().backward()
57-
if test_shape:
58-
assert out.shape[2:] == sample.shape[2:]
59-
60-
61-
@pytest.mark.parametrize("encoder_name", ENCODERS)
62-
@pytest.mark.parametrize("encoder_depth", [3, 5])
63-
@pytest.mark.parametrize(
64-
"model_class",
65-
[
66-
smp.FPN,
67-
smp.PSPNet,
68-
smp.Linknet,
69-
smp.Unet,
70-
smp.UnetPlusPlus,
71-
smp.MAnet,
72-
smp.UPerNet,
73-
smp.Segformer,
74-
],
75-
)
76-
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
77-
if (
78-
model_class is smp.Unet
79-
or model_class is smp.UnetPlusPlus
80-
or model_class is smp.MAnet
81-
):
82-
kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
83-
if model_class in [smp.UnetPlusPlus, smp.Linknet]:
84-
if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
85-
return # skip transformer-like model*
86-
if model_class is smp.FPN and encoder_depth != 5:
87-
if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
88-
return # skip transformer-like model*
89-
model = model_class(
90-
encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
91-
)
92-
sample = get_sample(model_class)
93-
model.eval()
94-
if encoder_depth == 5 and model_class != smp.PSPNet:
95-
test_shape = True
96-
else:
97-
test_shape = False
98-
99-
_test_forward(model, sample, test_shape)
100-
101-
102-
@pytest.mark.parametrize(
103-
"model_class",
104-
[
105-
smp.PAN,
106-
smp.FPN,
107-
smp.PSPNet,
108-
smp.Linknet,
109-
smp.Unet,
110-
smp.UnetPlusPlus,
111-
smp.MAnet,
112-
smp.DeepLabV3,
113-
smp.DeepLabV3Plus,
114-
smp.UPerNet,
115-
smp.Segformer,
116-
],
117-
)
118-
def test_forward_backward(model_class):
119-
sample = get_sample(model_class)
120-
model = model_class(DEFAULT_ENCODER, encoder_weights=None)
121-
_test_forward_backward(model, sample)
122-
123-
124-
@pytest.mark.parametrize(
125-
"model_class",
126-
[
127-
smp.PAN,
128-
smp.FPN,
129-
smp.PSPNet,
130-
smp.Linknet,
131-
smp.Unet,
132-
smp.UnetPlusPlus,
133-
smp.MAnet,
134-
smp.DeepLabV3,
135-
smp.DeepLabV3Plus,
136-
smp.UPerNet,
137-
smp.Segformer,
138-
],
139-
)
140-
def test_aux_output(model_class):
141-
model = model_class(
142-
DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)
143-
)
144-
sample = get_sample(model_class)
145-
label_size = (sample.shape[0], 2)
146-
mask, label = model(sample)
147-
assert label.size() == label_size
148-
149-
150-
@pytest.mark.parametrize("upsampling", [2, 4, 8])
151-
@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet])
152-
def test_upsample(model_class, upsampling):
153-
default_upsampling = 4 if model_class is smp.FPN else 8
154-
model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling)
155-
sample = get_sample(model_class)
156-
mask = model(sample)
157-
assert mask.size()[-1] / 64 == upsampling / default_upsampling
158-
159-
160-
@pytest.mark.parametrize("model_class", [smp.FPN])
161-
@pytest.mark.parametrize("in_channels", [1, 2, 4])
162-
def test_in_channels(model_class, in_channels):
163-
sample = torch.ones([1, in_channels, 64, 64])
164-
model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels)
165-
model.eval()
166-
with torch.no_grad():
167-
model(sample)
168-
169-
assert model.encoder._in_channels == in_channels
170-
171-
172-
@pytest.mark.parametrize("encoder_name", ENCODERS)
173-
def test_dilation(encoder_name):
174-
if (
175-
encoder_name in ["inceptionresnetv2", "xception", "inceptionv4"]
176-
or encoder_name.startswith("vgg")
177-
or encoder_name.startswith("densenet")
178-
or encoder_name.startswith("timm-res")
179-
):
180-
return
181-
182-
encoder = smp.encoders.get_encoder(encoder_name, output_stride=16)
183-
184-
encoder.eval()
185-
with torch.no_grad():
186-
sample = torch.ones([1, 3, 64, 64])
187-
output = encoder(sample)
188-
189-
shapes = [out.shape[-1] for out in output]
190-
assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation
191-
192-
193-
if __name__ == "__main__":
194-
pytest.main([__file__])
1+
# import pytest
2+
# import torch
3+
4+
# import segmentation_models_pytorch as smp # noqa
5+
6+
7+
# def get_encoders():
8+
# exclude_encoders = [
9+
# "senet154",
10+
# "resnext101_32x16d",
11+
# "resnext101_32x32d",
12+
# "resnext101_32x48d",
13+
# ]
14+
# encoders = smp.encoders.get_encoder_names()
15+
# encoders = [e for e in encoders if e not in exclude_encoders]
16+
# encoders.append("tu-resnet34") # for timm universal traditional-like encoder
17+
# encoders.append("tu-convnext_atto") # for timm universal transformer-like encoder
18+
# encoders.append("tu-darknet17") # for timm universal vgg-like encoder
19+
# return encoders
20+
21+
22+
# ENCODERS = get_encoders()
23+
# DEFAULT_ENCODER = "resnet18"
24+
25+
26+
# def get_sample(model_class):
27+
# if model_class in [
28+
# smp.FPN,
29+
# smp.Linknet,
30+
# smp.Unet,
31+
# smp.UnetPlusPlus,
32+
# smp.MAnet,
33+
# smp.Segformer,
34+
# ]:
35+
# sample = torch.ones([1, 3, 64, 64])
36+
# elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]:
37+
# sample = torch.ones([2, 3, 128, 128])
38+
# elif model_class in [smp.PSPNet, smp.UPerNet]:
39+
# # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input
40+
# # from PSPModule pooling in PSPNet/UPerNet.
41+
# sample = torch.ones([2, 3, 64, 64])
42+
# else:
43+
# raise ValueError("Not supported model class {}".format(model_class))
44+
# return sample
45+
46+
47+
# def _test_forward(model, sample, test_shape=False):
48+
# with torch.no_grad():
49+
# out = model(sample)
50+
# if test_shape:
51+
# assert out.shape[2:] == sample.shape[2:]
52+
53+
54+
# def _test_forward_backward(model, sample, test_shape=False):
55+
# out = model(sample)
56+
# out.mean().backward()
57+
# if test_shape:
58+
# assert out.shape[2:] == sample.shape[2:]
59+
60+
61+
# @pytest.mark.parametrize("encoder_name", ENCODERS)
62+
# @pytest.mark.parametrize("encoder_depth", [3, 5])
63+
# @pytest.mark.parametrize(
64+
# "model_class",
65+
# [
66+
# smp.FPN,
67+
# smp.PSPNet,
68+
# smp.Linknet,
69+
# smp.Unet,
70+
# smp.UnetPlusPlus,
71+
# smp.MAnet,
72+
# smp.UPerNet,
73+
# smp.Segformer,
74+
# ],
75+
# )
76+
# def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
77+
# if (
78+
# model_class is smp.Unet
79+
# or model_class is smp.UnetPlusPlus
80+
# or model_class is smp.MAnet
81+
# ):
82+
# kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
83+
# if model_class in [smp.UnetPlusPlus, smp.Linknet]:
84+
# if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
85+
# return # skip transformer-like model*
86+
# if model_class is smp.FPN and encoder_depth != 5:
87+
# if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
88+
# return # skip transformer-like model*
89+
# model = model_class(
90+
# encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
91+
# )
92+
# sample = get_sample(model_class)
93+
# model.eval()
94+
# if encoder_depth == 5 and model_class != smp.PSPNet:
95+
# test_shape = True
96+
# else:
97+
# test_shape = False
98+
99+
# _test_forward(model, sample, test_shape)
100+
101+
102+
# @pytest.mark.parametrize(
103+
# "model_class",
104+
# [
105+
# smp.PAN,
106+
# smp.FPN,
107+
# smp.PSPNet,
108+
# smp.Linknet,
109+
# smp.Unet,
110+
# smp.UnetPlusPlus,
111+
# smp.MAnet,
112+
# smp.DeepLabV3,
113+
# smp.DeepLabV3Plus,
114+
# smp.UPerNet,
115+
# smp.Segformer,
116+
# ],
117+
# )
118+
# def test_forward_backward(model_class):
119+
# sample = get_sample(model_class)
120+
# model = model_class(DEFAULT_ENCODER, encoder_weights=None)
121+
# _test_forward_backward(model, sample)
122+
123+
124+
# @pytest.mark.parametrize(
125+
# "model_class",
126+
# [
127+
# smp.PAN,
128+
# smp.FPN,
129+
# smp.PSPNet,
130+
# smp.Linknet,
131+
# smp.Unet,
132+
# smp.UnetPlusPlus,
133+
# smp.MAnet,
134+
# smp.DeepLabV3,
135+
# smp.DeepLabV3Plus,
136+
# smp.UPerNet,
137+
# smp.Segformer,
138+
# ],
139+
# )
140+
# def test_aux_output(model_class):
141+
# model = model_class(
142+
# DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)
143+
# )
144+
# sample = get_sample(model_class)
145+
# label_size = (sample.shape[0], 2)
146+
# mask, label = model(sample)
147+
# assert label.size() == label_size
148+
149+
150+
# @pytest.mark.parametrize("upsampling", [2, 4, 8])
151+
# @pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet])
152+
# def test_upsample(model_class, upsampling):
153+
# default_upsampling = 4 if model_class is smp.FPN else 8
154+
# model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling)
155+
# sample = get_sample(model_class)
156+
# mask = model(sample)
157+
# assert mask.size()[-1] / 64 == upsampling / default_upsampling
158+
159+
160+
# @pytest.mark.parametrize("model_class", [smp.FPN])
161+
# @pytest.mark.parametrize("in_channels", [1, 2, 4])
162+
# def test_in_channels(model_class, in_channels):
163+
# sample = torch.ones([1, in_channels, 64, 64])
164+
# model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels)
165+
# model.eval()
166+
# with torch.no_grad():
167+
# model(sample)
168+
169+
# assert model.encoder._in_channels == in_channels
170+
171+
172+
# @pytest.mark.parametrize("encoder_name", ENCODERS)
173+
# def test_dilation(encoder_name):
174+
# if (
175+
# encoder_name in ["inceptionresnetv2", "xception", "inceptionv4"]
176+
# or encoder_name.startswith("vgg")
177+
# or encoder_name.startswith("densenet")
178+
# or encoder_name.startswith("timm-res")
179+
# ):
180+
# return
181+
182+
# encoder = smp.encoders.get_encoder(encoder_name, output_stride=16)
183+
184+
# encoder.eval()
185+
# with torch.no_grad():
186+
# sample = torch.ones([1, 3, 64, 64])
187+
# output = encoder(sample)
188+
189+
# shapes = [out.shape[-1] for out in output]
190+
# assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation
191+
192+
193+
# if __name__ == "__main__":
194+
# pytest.main([__file__])

0 commit comments

Comments
 (0)