@@ -2,19 +2,24 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from typing import Optional, Sequence
 from segmentation_models_pytorch.base import modules as md
 
 
-class DecoderBlock(nn.Module):
+class UnetDecoderBlock(nn.Module):
+    """A decoder block in the U-Net architecture that performs upsampling and feature fusion."""
+
     def __init__(
         self,
-        in_channels,
-        skip_channels,
-        out_channels,
-        use_batchnorm=True,
-        attention_type=None,
+        in_channels: int,
+        skip_channels: int,
+        out_channels: int,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        interpolation_mode: str = "nearest",
     ):
         super().__init__()
+        self.interpolation_mode = interpolation_mode
         self.conv1 = md.Conv2dReLU(
             in_channels + skip_channels,
             out_channels,
@@ -34,19 +39,31 @@ def __init__(
         )
         self.attention2 = md.Attention(attention_type, in_channels=out_channels)
 
-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if skip is not None:
-            x = torch.cat([x, skip], dim=1)
-            x = self.attention1(x)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.attention2(x)
-        return x
+    def forward(
+        self,
+        feature_map: torch.Tensor,
+        target_height: int,
+        target_width: int,
+        skip_connection: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        feature_map = F.interpolate(
+            feature_map,
+            size=(target_height, target_width),
+            mode=self.interpolation_mode,
+        )
+        if skip_connection is not None:
+            feature_map = torch.cat([feature_map, skip_connection], dim=1)
+            feature_map = self.attention1(feature_map)
+        feature_map = self.conv1(feature_map)
+        feature_map = self.conv2(feature_map)
+        feature_map = self.attention2(feature_map)
+        return feature_map
+
 
+class UnetCenterBlock(nn.Sequential):
+    """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
 
-class CenterBlock(nn.Sequential):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
         conv1 = md.Conv2dReLU(
             in_channels,
             out_channels,
@@ -65,14 +82,21 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
 
 
 class UnetDecoder(nn.Module):
+    """The decoder part of the U-Net architecture.
+
+    Takes encoded features from different stages of the encoder and progressively upsamples them while
+    combining with skip connections. This helps preserve fine-grained details in the final segmentation.
+    """
+
     def __init__(
         self,
-        encoder_channels,
-        decoder_channels,
-        n_blocks=5,
-        use_batchnorm=True,
-        attention_type=None,
-        center=False,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
+        n_blocks: int = 5,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        add_center_block: bool = False,
+        interpolation_mode: str = "nearest",
     ):
         super().__init__()
 
@@ -94,31 +118,45 @@ def __init__(
         skip_channels = list(encoder_channels[1:]) + [0]
         out_channels = decoder_channels
 
-        if center:
-            self.center = CenterBlock(
+        if add_center_block:
+            self.center = UnetCenterBlock(
                 head_channels, head_channels, use_batchnorm=use_batchnorm
             )
         else:
             self.center = nn.Identity()
 
         # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
-        blocks = [
-            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
-            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
-        ]
-        self.blocks = nn.ModuleList(blocks)
-
-    def forward(self, *features):
+        self.blocks = nn.ModuleList()
+        for block_in_channels, block_skip_channels, block_out_channels in zip(
+            in_channels, skip_channels, out_channels
+        ):
+            block = UnetDecoderBlock(
+                block_in_channels,
+                block_skip_channels,
+                block_out_channels,
+                use_batchnorm=use_batchnorm,
+                attention_type=attention_type,
+                interpolation_mode=interpolation_mode,
+            )
+            self.blocks.append(block)
+
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
+        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
+        spatial_shapes = [feature.shape[2:] for feature in features]
+        spatial_shapes = spatial_shapes[::-1]
+
         features = features[1:]  # remove first skip with same spatial resolution
         features = features[::-1]  # reverse channels to start from head of encoder
 
         head = features[0]
-        skips = features[1:]
+        skip_connections = features[1:]
 
         x = self.center(head)
+
         for i, decoder_block in enumerate(self.blocks):
-            skip = skips[i] if i < len(skips) else None
-            x = decoder_block(x, skip)
+            # upsample to the next spatial shape
+            height, width = spatial_shapes[i + 1]
+            skip_connection = skip_connections[i] if i < len(skip_connections) else None
+            x = decoder_block(x, height, width, skip_connection=skip_connection)
 
         return x
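
For reference, a minimal sketch of how the refactored UnetDecoderBlock is driven after this change: the block now receives an explicit target size instead of assuming a fixed scale_factor=2 upsample. Channel counts, tensor shapes, and the import path below are illustrative assumptions, not part of the commit.

import torch

# Module path assumed; it may differ between library versions.
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoderBlock

block = UnetDecoderBlock(in_channels=1024, skip_channels=512, out_channels=256)

x = torch.randn(1, 1024, 8, 8)      # deepest encoder feature map
skip = torch.randn(1, 512, 16, 16)  # skip connection at the target resolution

# Upsamples x to (16, 16), concatenates the skip, then applies the two conv blocks.
out = block(x, target_height=16, target_width=16, skip_connection=skip)
print(out.shape)  # torch.Size([1, 256, 16, 16])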
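
And a sketch of the full UnetDecoder with the renamed add_center_block and new interpolation_mode arguments; the forward pass now derives each stage's target size from the incoming feature maps rather than from a fixed 2x factor. The channel progression, input size, and import path are assumptions for illustration, and the exact per-block channel math comes from constructor code not shown in this diff.

import torch

# Module path assumed; it may differ between library versions.
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder

encoder_channels = (3, 64, 128, 256, 512, 1024)  # example encoder channel progression
decoder_channels = (256, 128, 64, 32, 16)

decoder = UnetDecoder(
    encoder_channels=encoder_channels,
    decoder_channels=decoder_channels,
    n_blocks=5,
    add_center_block=False,
    interpolation_mode="nearest",
)

# Fake encoder pyramid: resolution halves at every stage, starting from 256x256.
features = [
    torch.randn(1, channels, 256 // 2**i, 256 // 2**i)
    for i, channels in enumerate(encoder_channels)
]

out = decoder(*features)
print(out.shape)  # expected torch.Size([1, 16, 256, 256]) with the sizes above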