
tu-convnextv2_base cannot be combined with UnetPlusPlus #1260

@wenwwww

Description

In segmentation_models_pytorch 0.5.0, the tu-convnextv2_base encoder cannot be combined with the UnetPlusPlus decoder. The forward pass fails with a RuntimeError.

code

import torch
import segmentation_models_pytorch as smp

# Build UnetPlusPlus on a timm ConvNeXt-V2 encoder (the "tu-" prefix selects timm)
model = smp.UnetPlusPlus(
    encoder_name="tu-convnextv2_base",
    encoder_weights="imagenet",
    in_channels=3,                           # Input: RGB image
    encoder_depth=4,                         # Number of encoder stages
    decoder_channels=(1024, 512, 256, 128),  # One entry per encoder stage
    decoder_use_norm="batchnorm",
    classes=19,
    decoder_interpolation="bicubic",
    activation=None,                         # Return raw logits
)

dummy_input = torch.randn(1, 3, 512, 1024)
features = model(dummy_input)  # Fails with the RuntimeError shown in the output

output

C:\Users\123\miniconda3\Lib\site-packages\torch\nn\init.py:582: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
Traceback (most recent call last):
  File "c:\Users\123\smp_test\testsmp1.py", line 24, in <module>
    features = model(dummy_input)
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\123\miniconda3\Lib\site-packages\segmentation_models_pytorch\base\model.py", line 67, in forward
    decoder_output = self.decoder(features)
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\123\miniconda3\Lib\site-packages\segmentation_models_pytorch\decoders\unetplusplus\decoder.py", line 153, in forward
    output = self.blocks[f"x_{depth_idx}_{depth_idx}"](
        features[depth_idx], features[depth_idx + 1]
    )
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\123\miniconda3\Lib\site-packages\segmentation_models_pytorch\decoders\unetplusplus\decoder.py", line 48, in forward
    x = self.conv1(x)
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\container.py", line 244, in forward
    input = module(input)
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\conv.py", line 548, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\conv.py", line 543, in _conv_forward
    return F.conv2d(
           ~~~~~~~~^
        input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
RuntimeError: Given groups=1, expected weight to be at least 1 at dimension 0, but got weight of size [0, 128, 3, 3] instead

why

The cause is the following code in Lib\site-packages\segmentation_models_pytorch\encoders\timm_universal.py. For transformer-style timm models, the encoder inserts a dummy 0 entry into its out_channels list. When the UnetPlusPlus decoder is constructed from that list, it builds a convolution layer with 0 output channels, which is the [0, 128, 3, 3] weight reported in the traceback.

if self._is_transformer_style:
    # Transformer-like models (start at scale 4)
    if "tresnet" in name:
        # 'tresnet' models start feature extraction at stage 1,
        # so out_indices=(1, 2, 3, 4) for depth=5.
        common_kwargs["out_indices"] = tuple(range(1, depth))
    else:
        # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5.
        common_kwargs["out_indices"] = tuple(range(depth - 1))

    timm_model_kwargs = _merge_kwargs_no_duplicates(common_kwargs, kwargs)
    self.model = timm.create_model(name, **timm_model_kwargs)

    # Add a dummy output channel (0) to align with traditional encoder structures.
    self._out_channels = (
        [in_channels] + [0] + self.model.feature_info.channels()
    )

...
@property
def out_channels(self) -> list[int]:
    """
    Returns the number of output channels for each feature stage.

    Returns:
        list[int]: A list of channel dimensions at each scale.
    """
    return self._out_channels
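The channel arithmetic can be replayed in plain Python to see where the 0-channel convolution comes from. This is a simplified sketch, not library code: `encoder_out_channels` mirrors the transformer-style branch quoted above, while `unetplusplus_block_channels` is a hypothetical re-derivation of how the UnetPlusPlus decoder in 0.5.0 is assumed to size its dense blocks (the real decoder may differ in detail). The ConvNeXt-V2 base stage widths (128, 256, 512, 1024) are also an assumption.

```python
def encoder_out_channels(in_channels, depth, stage_channels):
    """Mirror of the transformer-style branch: timm returns depth - 1 stages
    and a dummy 0 entry is inserted right after the input channels."""
    out_indices = tuple(range(depth - 1))
    timm_channels = [stage_channels[i] for i in out_indices]
    return [in_channels] + [0] + timm_channels


def unetplusplus_block_channels(encoder_channels, decoder_channels):
    """Hypothetical re-derivation of UnetPlusPlus dense-block sizing (assumption).

    Returns {block_name: (conv1_in_channels, out_channels)}; conv1 is assumed to
    see the upsampled input concatenated with its skip connections.
    Only the dense blocks are modeled.
    """
    enc = encoder_channels[1:][::-1]  # drop input-resolution entry, deepest stage first
    in_chs = [enc[0]] + list(decoder_channels[:-1])
    skip_chs = list(enc[1:]) + [0]
    blocks = {}
    for layer_idx in range(len(in_chs) - 1):
        for depth_idx in range(layer_idx + 1):
            if depth_idx == 0:
                in_ch = in_chs[layer_idx]
                skip_ch = skip_chs[layer_idx] * (layer_idx + 1)
                out_ch = decoder_channels[layer_idx]
            else:
                # Nested blocks take their output width from the skip channels,
                # so a 0 in skip_chs becomes a 0-output convolution.
                in_ch = skip_chs[layer_idx - 1]
                skip_ch = skip_chs[layer_idx] * (layer_idx + 1 - depth_idx)
                out_ch = skip_chs[layer_idx]
            blocks[f"x_{depth_idx}_{layer_idx}"] = (in_ch + skip_ch, out_ch)
    return blocks


# Assumed ConvNeXt-V2 base stage widths.
CONVNEXTV2_BASE = (128, 256, 512, 1024)

enc_channels = encoder_out_channels(3, 4, CONVNEXTV2_BASE)
print(enc_channels)  # [3, 0, 128, 256, 512]

blocks = unetplusplus_block_channels(enc_channels, (1024, 512, 256, 128))
for name, (cin, cout) in blocks.items():
    if cout == 0:
        # 3x3 kernels, matching the weight shape in the traceback
        print(name, "-> conv weight shape", [cout, cin, 3, 3])
```

Under these assumptions, blocks x_1_2 and x_2_2 both come out with 128 input channels and 0 output channels, which is consistent with the [0, 128, 3, 3] weight in the traceback. The failing call in the traceback uses block names of the form x_{depth_idx}_{depth_idx}, so x_2_2 would be the one it reaches.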

Changing the reported channel count alone is not enough, because the encoder's forward method in the same file (Lib\site-packages\segmentation_models_pytorch\encoders\timm_universal.py) also prepends a zero-channel dummy feature map:

if self._is_transformer_style:
    B, _, H, W = x.shape
    dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
    features = [dummy] + features

As a result, the encoder output still contains a feature map with 0 channels. If only out_channels is patched, the UnetPlusPlus decoder fails with a different error: after the feature maps are concatenated in the dense skip connections, the number of input channels no longer matches what one of the convolution blocks was built to expect.
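The concatenation mismatch can be reproduced in isolation. This is a minimal PyTorch sketch, not taken from smp: the tensor sizes are illustrative, and PATCHED_SKIP_CHANNELS = 64 is a hypothetical value standing in for whatever non-zero count a patched out_channels list might report for the dummy stage. Because `torch.cat` along the channel dimension contributes nothing for a 0-channel tensor, a convolution sized from the patched count receives fewer input channels than it expects.

```python
import torch
import torch.nn as nn

B, H, W = 1, 64, 64

# Upsampled decoder feature (illustrative size)
x = torch.randn(B, 128, H, W)

# Dummy skip with 0 channels, like the one timm_universal.py prepends
dummy = torch.empty([B, 0, H, W])

# Concatenating a 0-channel tensor leaves the channel count unchanged
cat = torch.cat([x, dummy], dim=1)
print(cat.shape)  # torch.Size([1, 128, 64, 64])

# Hypothetical: a patched out_channels reports 64 channels for the dummy stage,
# so the decoder would size its conv for 128 + 64 input channels.
PATCHED_SKIP_CHANNELS = 64
conv = nn.Conv2d(128 + PATCHED_SKIP_CHANNELS, 64, kernel_size=3, padding=1)

try:
    conv(cat)
except RuntimeError as err:
    # The conv was built for 192 input channels but receives 128
    print("RuntimeError:", err)
```

Any fix therefore has to keep the reported channel list and the tensors actually passed to the decoder in agreement, rather than changing only one of them.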
