Commit 6b2ca90

Enable any res for Unet and better docs

1 parent 5b75f11

File tree

segmentation_models_pytorch/decoders/unet/decoder.py
segmentation_models_pytorch/decoders/unet/model.py

2 files changed (+101, -40 lines)

segmentation_models_pytorch/decoders/unet/decoder.py

Lines changed: 62 additions & 33 deletions
@@ -2,19 +2,22 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from typing import Optional, Sequence
 from segmentation_models_pytorch.base import modules as md
 
 
 class DecoderBlock(nn.Module):
     def __init__(
         self,
-        in_channels,
-        skip_channels,
-        out_channels,
-        use_batchnorm=True,
-        attention_type=None,
+        in_channels: int,
+        skip_channels: int,
+        out_channels: int,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        interpolation_mode: str = "nearest",
     ):
         super().__init__()
+        self.interpolate_mode = interpolation_mode
         self.conv1 = md.Conv2dReLU(
             in_channels + skip_channels,
             out_channels,
@@ -34,19 +37,32 @@ def __init__(
         )
         self.attention2 = md.Attention(attention_type, in_channels=out_channels)
 
-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if skip is not None:
-            x = torch.cat([x, skip], dim=1)
-        x = self.attention1(x)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.attention2(x)
-        return x
+    def forward(
+        self,
+        feature_map: torch.Tensor,
+        target_height: int,
+        target_width: int,
+        skip_connection: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Upsample feature map to the given spatial shape, concatenate with skip connection,
+        apply attention block (if specified) and then apply two convolutions.
+        """
+        feature_map = F.interpolate(
+            feature_map, size=(target_height, target_width), mode=self.interpolate_mode
+        )
+        if skip_connection is not None:
+            feature_map = torch.cat([feature_map, skip_connection], dim=1)
+        feature_map = self.attention1(feature_map)
+        feature_map = self.conv1(feature_map)
+        feature_map = self.conv2(feature_map)
+        feature_map = self.attention2(feature_map)
+        return feature_map
 
 
 class CenterBlock(nn.Sequential):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
+
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
         conv1 = md.Conv2dReLU(
             in_channels,
             out_channels,
@@ -67,12 +83,12 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
 class UnetDecoder(nn.Module):
     def __init__(
         self,
-        encoder_channels,
-        decoder_channels,
-        n_blocks=5,
-        use_batchnorm=True,
-        attention_type=None,
-        center=False,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
+        n_blocks: int = 5,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        add_center_block: bool = False,
     ):
         super().__init__()
 
@@ -94,31 +110,44 @@ def __init__(
         skip_channels = list(encoder_channels[1:]) + [0]
         out_channels = decoder_channels
 
-        if center:
+        if add_center_block:
             self.center = CenterBlock(
                 head_channels, head_channels, use_batchnorm=use_batchnorm
             )
         else:
             self.center = nn.Identity()
 
         # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
-        blocks = [
-            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
-            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
-        ]
-        self.blocks = nn.ModuleList(blocks)
-
-    def forward(self, *features):
+        self.blocks = nn.ModuleList()
+        for block_in_channels, block_skip_channels, block_out_channels in zip(
+            in_channels, skip_channels, out_channels
+        ):
+            block = DecoderBlock(
+                block_in_channels,
+                block_skip_channels,
+                block_out_channels,
+                use_batchnorm=use_batchnorm,
+                attention_type=attention_type,
+            )
+            self.blocks.append(block)
+
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
+        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
+        spatial_shapes = [feature.shape[2:] for feature in features]
+        spatial_shapes = spatial_shapes[::-1]
+
         features = features[1:]  # remove first skip with same spatial resolution
         features = features[::-1]  # reverse channels to start from head of encoder
 
         head = features[0]
-        skips = features[1:]
+        skip_connections = features[1:]
 
         x = self.center(head)
+
         for i, decoder_block in enumerate(self.blocks):
-            skip = skips[i] if i < len(skips) else None
-            x = decoder_block(x, skip)
+            # upsample to the next spatial shape
+            height, width = spatial_shapes[i + 1]
+            skip_connection = skip_connections[i] if i < len(skip_connections) else None
+            x = decoder_block(x, height, width, skip_connection=skip_connection)
 
         return x
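
Taken together, these changes replace the decoder's fixed scale_factor=2 upsampling with interpolation to the exact spatial shape of the matching encoder feature, as recorded in spatial_shapes. A minimal standalone sketch of why this enables arbitrary input resolutions (the shapes below are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

# A stride-2 convolution maps 75 -> 38 (rounding up), so the skip feature and
# the deeper feature no longer differ by an exact factor of 2.
skip = torch.rand(1, 64, 75, 75)   # encoder feature at an odd resolution
x = torch.rand(1, 128, 38, 38)     # next, deeper encoder feature

# Old behaviour: a fixed factor gives 76x76, which cannot be concatenated with 75x75.
up_fixed = F.interpolate(x, scale_factor=2, mode="nearest")

# New behaviour: upsample to the recorded skip shape, which always matches.
up_sized = F.interpolate(x, size=skip.shape[2:], mode="nearest")

print(up_fixed.shape[2:])  # torch.Size([76, 76])
print(up_sized.shape[2:])  # torch.Size([75, 75])
fused = torch.cat([up_sized, skip], dim=1)  # 128 + 64 channels at 75x75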

segmentation_models_pytorch/decoders/unet/model.py

Lines changed: 39 additions & 7 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union, Tuple, Callable
+from typing import Any, Optional, Union, Callable, Sequence
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -12,10 +12,21 @@
 
 
 class Unet(SegmentationModel):
-    """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder*
-    and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial
-    resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation*
-    for fusing decoder blocks with skip connections.
+    """
+    U-Net is a fully convolutional neural network architecture designed for semantic image segmentation.
+
+    It consists of two main parts:
+
+    1. An encoder (downsampling path) that extracts increasingly abstract features
+    2. A decoder (upsampling path) that gradually recovers spatial details
+
+    The key is the use of skip connections between corresponding encoder and decoder layers.
+    These connections allow the decoder to access fine-grained details from earlier encoder layers,
+    which helps produce more precise segmentation masks.
+
+    The skip connections work by concatenating feature maps from the encoder directly into the decoder
+    at corresponding resolutions. This helps preserve important spatial information that would
+    otherwise be lost during the encoding process.
 
     Args:
         encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
@@ -51,19 +62,39 @@ class Unet(SegmentationModel):
     Returns:
         ``torch.nn.Module``: Unet
 
+    Example:
+        .. code-block:: python
+
+            import torch
+            import segmentation_models_pytorch as smp
+
+            model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5)
+            model.eval()
+
+            # generate random images
+            images = torch.rand(2, 3, 256, 256)
+
+            with torch.inference_mode():
+                mask = model(images)
+
+            print(mask.shape)
+            # torch.Size([2, 5, 256, 256])
+
     .. _Unet:
         https://arxiv.org/abs/1505.04597
 
     """
 
+    requires_divisible_input_shape = False
+
     @supports_config_loading
     def __init__(
         self,
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
         decoder_use_batchnorm: bool = True,
-        decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
+        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
         decoder_attention_type: Optional[str] = None,
         in_channels: int = 3,
         classes: int = 1,
@@ -81,12 +112,13 @@ def __init__(
             **kwargs,
         )
 
+        add_center_block = encoder_name.startswith("vgg")
         self.decoder = UnetDecoder(
             encoder_channels=self.encoder.out_channels,
             decoder_channels=decoder_channels,
             n_blocks=encoder_depth,
             use_batchnorm=decoder_use_batchnorm,
-            center=True if encoder_name.startswith("vgg") else False,
+            add_center_block=add_center_block,
             attention_type=decoder_attention_type,
         )
 
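The new requires_divisible_input_shape = False class attribute signals that, with this decoder, input sides no longer need to be divisible by the encoder's output stride, so the base model's shape check can be skipped. A quick end-to-end sanity check, assuming a segmentation_models_pytorch build that includes this commit (encoder weights disabled to avoid a download):

import torch
import segmentation_models_pytorch as smp

# 250 is not divisible by 32: five fixed x2 upsamplings could never restore
# 250 exactly, so such inputs previously required padding or cropping.
model = smp.Unet("resnet18", encoder_weights=None, classes=2)
model.eval()

images = torch.rand(1, 3, 250, 250)
with torch.inference_mode():
    mask = model(images)

print(mask.shape)  # expected: torch.Size([1, 2, 250, 250])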