Skip to content

Commit 748070e

Browse files
committed
update tests
add UPerNet for test_models
1 parent 43000a3 commit 748070e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/test_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def get_sample(model_class):
2929
smp.PSPNet,
3030
smp.UnetPlusPlus,
3131
smp.MAnet,
32+
smp.UPerNet,
3233
]:
3334
sample = torch.ones([1, 3, 64, 64])
3435
elif model_class == smp.PAN:
@@ -57,7 +58,8 @@ def _test_forward_backward(model, sample, test_shape=False):
5758
@pytest.mark.parametrize("encoder_name", ENCODERS)
5859
@pytest.mark.parametrize("encoder_depth", [3, 5])
5960
@pytest.mark.parametrize(
60-
"model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus]
61+
"model_class",
62+
[smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet],
6163
)
6264
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
6365
if (

0 commit comments

Comments
 (0)