From ea54d936a64d032cba7133967bface2b43571da2 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Thu, 10 Oct 2024 23:03:55 +0800 Subject: [PATCH 1/2] Update test_models.py --- tests/test_models.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index a1b5f2c6..858065e1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -29,12 +29,11 @@ def get_sample(model_class): smp.PSPNet, smp.UnetPlusPlus, smp.MAnet, - smp.UPerNet, ]: sample = torch.ones([1, 3, 64, 64]) - elif model_class == smp.PAN: + elif model_class == smp.PAN or model_class == smp.UPerNet: sample = torch.ones([2, 3, 256, 256]) - elif model_class == smp.DeepLabV3: + elif model_class == smp.DeepLabV3 or model_class == smp.DeepLabV3Plus: sample = torch.ones([2, 3, 128, 128]) else: raise ValueError("Not supported model class {}".format(model_class)) @@ -102,6 +101,8 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs): smp.UnetPlusPlus, smp.MAnet, smp.DeepLabV3, + smp.DeepLabV3Plus, + smp.UPerNet, ], ) def test_forward_backward(model_class): @@ -112,7 +113,18 @@ def test_forward_backward(model_class): @pytest.mark.parametrize( "model_class", - [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet], + [ + smp.PAN, + smp.FPN, + smp.PSPNet, + smp.Linknet, + smp.Unet, + smp.UnetPlusPlus, + smp.MAnet, + smp.DeepLabV3, + smp.DeepLabV3Plus, + smp.UPerNet, + ], ) def test_aux_output(model_class): model = model_class( From 8df56ddc2c45f97ca1a33ed133ae9282daedc9d9 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sat, 12 Oct 2024 10:28:42 +0800 Subject: [PATCH 2/2] Update test_models.py --- tests/test_models.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 858065e1..206635e2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -23,18 +23,21 @@ def get_encoders(): def get_sample(model_class): if model_class in [ - smp.Unet, - smp.Linknet, smp.FPN, - smp.PSPNet, + smp.Linknet, + smp.Unet, smp.UnetPlusPlus, smp.MAnet, ]: sample = torch.ones([1, 3, 64, 64]) - elif model_class == smp.PAN or model_class == smp.UPerNet: + elif model_class == smp.PAN: sample = torch.ones([2, 3, 256, 256]) - elif model_class == smp.DeepLabV3 or model_class == smp.DeepLabV3Plus: + elif model_class in [smp.DeepLabV3, smp.DeepLabV3Plus]: sample = torch.ones([2, 3, 128, 128]) + elif model_class in [smp.PSPNet, smp.UPerNet]: + # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input + # from PSPModule pooling in PSPNet/UPerNet. + sample = torch.ones([2, 3, 64, 64]) else: raise ValueError("Not supported model class {}".format(model_class)) return sample