-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Description
In segmentation_models_pytorch 0.5.0, the `tu-convnextv2_base` encoder cannot be combined with `UnetPlusPlus`.
code
import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.decoders.unetplusplus.decoder import DecoderBlock
# Reproduction script: build a UnetPlusPlus model on top of the
# 'tu-convnextv2_base' encoder, then create a random input batch for it.
# NOTE(review): nothing in this call sets a features_only option explicitly;
# the report traces the encoder behaviour to encoders/timm_universal.py.
model = smp.UnetPlusPlus(
encoder_name='tu-convnextv2_base',
encoder_weights="imagenet",
in_channels=3, # Input: RGB image (3 channels)
encoder_depth=4, # Encoder depth; decoder_channels below has one entry per depth level
decoder_channels=(1024, 512, 256,128),
decoder_use_norm='batchnorm',
classes=19,
decoder_interpolation= "bicubic",
activation=None , # Reporter's note: assumes labels are normalized to [0, 1] -- TODO confirm how this relates to activation=None
)
# Random input of shape (1, 3, 512, 1024); the channel count matches in_channels=3.
dummy_input = torch.randn(1, 3, 512, 1024)
features = model(dummy_input)
Output
C:\Users\123\miniconda3\Lib\site-packages\torch\nn\init.py:582: UserWarning: Initializing zero-element tensors is a no-op
warnings.warn("Initializing zero-element tensors is a no-op")
Traceback (most recent call last):
File "c:\Users\123\smp_test\testsmp1.py", line 24, in <module>
features = model(dummy_input)
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\123\miniconda3\Lib\site-packages\segmentation_models_pytorch\base\model.py", line 67, in forward
decoder_output = self.decoder(features)
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\123\miniconda3\Lib\site-packages\segmentation_models_pytorch\decoders\unetplusplus\decoder.py", line 153, in forward
output = self.blocks[f"x_{depth_idx}_{depth_idx}"](
features[depth_idx], features[depth_idx + 1]
)
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\123\miniconda3\Lib\site-packages\segmentation_models_pytorch\decoders\unetplusplus\decoder.py", line 48, in forward
x = self.conv1(x)
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\container.py", line 244, in forward
input = module(input)
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\conv.py", line 548, in forward
return self._conv_forward(input, self.weight, self.bias)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\123\miniconda3\Lib\site-packages\torch\nn\modules\conv.py", line 543, in _conv_forward
return F.conv2d(
~~~~~~~~^
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
RuntimeError: Given groups=1, expected weight to be at least 1 at dimension 0, but got weight of size [0, 128, 3, 3] instead
Why
The following code exists in `Lib\site-packages\segmentation_models_pytorch\encoders\timm_universal.py`, which causes the encoder to expose an `out_channels` list containing a 0-channel entry. Consequently, when the UnetPlusPlus decoder is constructed, it erroneously builds a convolution layer with 0 output channels.
if self._is_transformer_style:
# Transformer-like models (start at scale 4)
if "tresnet" in name:
# 'tresnet' models start feature extraction at stage 1,
# so out_indices=(1, 2, 3, 4) for depth=5.
common_kwargs["out_indices"] = tuple(range(1, depth))
else:
# Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5.
common_kwargs["out_indices"] = tuple(range(depth - 1))
timm_model_kwargs = _merge_kwargs_no_duplicates(common_kwargs, kwargs)
self.model = timm.create_model(name, **timm_model_kwargs)
# Add a dummy output channel (0) to align with traditional encoder structures.
self._out_channels = (
[in_channels] + [0] + self.model.feature_info.channels()
)
...
@property
def out_channels(self) -> list[int]:
"""
Returns the number of output channels for each feature stage.
Returns:
list[int]: A list of channel dimensions at each scale.
"""
return self._out_channels
Simply modifying the channel count is insufficient, because the following code also exists in `Lib\site-packages\segmentation_models_pytorch\encoders\timm_universal.py`:
if self._is_transformer_style:
B, _, H, W = x.shape
dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
features = [dummy] + features
At this point, the encoder's output includes a feature map with 0 channels. During the forward pass of the UnetPlusPlus decoder, another error then occurs: after the feature maps are concatenated, the resulting channel count does not match the input channels expected by one of the convolution blocks in the dense connections.