Skip to content

Commit d5a80df

Browse files
committed
Interpolation for unet
1 parent eb81c1f commit d5a80df

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

segmentation_models_pytorch/decoders/unet/decoder.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from segmentation_models_pytorch.base import modules as md
77

88

9-
class DecoderBlock(nn.Module):
9+
class UnetDecoderBlock(nn.Module):
10+
"""A decoder block in the U-Net architecture that performs upsampling and feature fusion."""
11+
1012
def __init__(
1113
self,
1214
in_channels: int,
@@ -17,7 +19,7 @@ def __init__(
1719
interpolation_mode: str = "nearest",
1820
):
1921
super().__init__()
20-
self.interpolate_mode = interpolation_mode
22+
self.interpolation_mode = interpolation_mode
2123
self.conv1 = md.Conv2dReLU(
2224
in_channels + skip_channels,
2325
out_channels,
@@ -44,11 +46,10 @@ def forward(
4446
target_width: int,
4547
skip_connection: Optional[torch.Tensor] = None,
4648
) -> torch.Tensor:
47-
"""Upsample feature map to the given spatial shape, concatenate with skip connection,
48-
apply attention block (if specified) and then apply two convolutions.
49-
"""
5049
feature_map = F.interpolate(
51-
feature_map, size=(target_height, target_width), mode=self.interpolate_mode
50+
feature_map,
51+
size=(target_height, target_width),
52+
mode=self.interpolation_mode,
5253
)
5354
if skip_connection is not None:
5455
feature_map = torch.cat([feature_map, skip_connection], dim=1)
@@ -59,7 +60,7 @@ def forward(
5960
return feature_map
6061

6162

62-
class CenterBlock(nn.Sequential):
63+
class UnetCenterBlock(nn.Sequential):
6364
"""Center block of the Unet decoder. Applied to the last feature map of the encoder."""
6465

6566
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
@@ -81,6 +82,12 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
8182

8283

8384
class UnetDecoder(nn.Module):
85+
"""The decoder part of the U-Net architecture.
86+
87+
Takes encoded features from different stages of the encoder and progressively upsamples them while
88+
combining with skip connections. This helps preserve fine-grained details in the final segmentation.
89+
"""
90+
8491
def __init__(
8592
self,
8693
encoder_channels: Sequence[int],
@@ -89,6 +96,7 @@ def __init__(
8996
use_batchnorm: bool = True,
9097
attention_type: Optional[str] = None,
9198
add_center_block: bool = False,
99+
interpolation_mode: str = "nearest",
92100
):
93101
super().__init__()
94102

@@ -111,7 +119,7 @@ def __init__(
111119
out_channels = decoder_channels
112120

113121
if add_center_block:
114-
self.center = CenterBlock(
122+
self.center = UnetCenterBlock(
115123
head_channels, head_channels, use_batchnorm=use_batchnorm
116124
)
117125
else:
@@ -122,12 +130,13 @@ def __init__(
122130
for block_in_channels, block_skip_channels, block_out_channels in zip(
123131
in_channels, skip_channels, out_channels
124132
):
125-
block = DecoderBlock(
133+
block = UnetDecoderBlock(
126134
block_in_channels,
127135
block_skip_channels,
128136
block_out_channels,
129137
use_batchnorm=use_batchnorm,
130138
attention_type=attention_type,
139+
interpolation_mode=interpolation_mode,
131140
)
132141
self.blocks.append(block)
133142

segmentation_models_pytorch/decoders/unet/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class Unet(SegmentationModel):
4444
Available options are **True, False, "inplace"**
4545
decoder_attention_type: Attention module used in decoder of the model. Available options are
4646
**None** and **scse** (https://arxiv.org/abs/1808.08127).
47+
decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are
48+
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
4749
in_channels: A number of input channels for the model, default is 3 (RGB images)
4850
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
4951
activation: An activation function to apply after the final convolution layer.
@@ -96,6 +98,7 @@ def __init__(
9698
decoder_use_batchnorm: bool = True,
9799
decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
98100
decoder_attention_type: Optional[str] = None,
101+
decoder_interpolation_mode: str = "nearest",
99102
in_channels: int = 3,
100103
classes: int = 1,
101104
activation: Optional[Union[str, Callable]] = None,
@@ -120,6 +123,7 @@ def __init__(
120123
use_batchnorm=decoder_use_batchnorm,
121124
add_center_block=add_center_block,
122125
attention_type=decoder_attention_type,
126+
interpolation_mode=decoder_interpolation_mode,
123127
)
124128

125129
self.segmentation_head = SegmentationHead(

0 commit comments

Comments
 (0)