Skip to content

Commit bf57b0e

Browse files
committed
update tests/test_models
1 parent 8c91c09 commit bf57b0e

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

tests/test_models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def get_sample(model_class):
2828
smp.Unet,
2929
smp.UnetPlusPlus,
3030
smp.MAnet,
31+
smp.Segformer,
3132
]:
3233
sample = torch.ones([1, 3, 64, 64])
3334
elif model_class == smp.PAN:
@@ -61,7 +62,16 @@ def _test_forward_backward(model, sample, test_shape=False):
6162
@pytest.mark.parametrize("encoder_depth", [3, 5])
6263
@pytest.mark.parametrize(
6364
"model_class",
64-
[smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet],
65+
[
66+
smp.FPN,
67+
smp.PSPNet,
68+
smp.Linknet,
69+
smp.Unet,
70+
smp.UnetPlusPlus,
71+
smp.MAnet,
72+
smp.UPerNet,
73+
smp.Segformer,
74+
],
6575
)
6676
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
6777
if (
@@ -106,6 +116,7 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
106116
smp.DeepLabV3,
107117
smp.DeepLabV3Plus,
108118
smp.UPerNet,
119+
smp.Segformer,
109120
],
110121
)
111122
def test_forward_backward(model_class):
@@ -127,6 +138,7 @@ def test_forward_backward(model_class):
127138
smp.DeepLabV3,
128139
smp.DeepLabV3Plus,
129140
smp.UPerNet,
141+
smp.Segformer,
130142
],
131143
)
132144
def test_aux_output(model_class):

0 commit comments

Comments
 (0)