6
6
from segmentation_models_pytorch .base import modules as md
7
7
8
8
9
- class DecoderBlock (nn .Module ):
9
+ class UnetDecoderBlock (nn .Module ):
10
+ """A decoder block in the U-Net architecture that performs upsampling and feature fusion."""
11
+
10
12
def __init__ (
11
13
self ,
12
14
in_channels : int ,
@@ -17,7 +19,7 @@ def __init__(
17
19
interpolation_mode : str = "nearest" ,
18
20
):
19
21
super ().__init__ ()
20
- self .interpolate_mode = interpolation_mode
22
+ self .interpolation_mode = interpolation_mode
21
23
self .conv1 = md .Conv2dReLU (
22
24
in_channels + skip_channels ,
23
25
out_channels ,
@@ -44,11 +46,10 @@ def forward(
44
46
target_width : int ,
45
47
skip_connection : Optional [torch .Tensor ] = None ,
46
48
) -> torch .Tensor :
47
- """Upsample feature map to the given spatial shape, concatenate with skip connection,
48
- apply attention block (if specified) and then apply two convolutions.
49
- """
50
49
feature_map = F .interpolate (
51
- feature_map , size = (target_height , target_width ), mode = self .interpolate_mode
50
+ feature_map ,
51
+ size = (target_height , target_width ),
52
+ mode = self .interpolation_mode ,
52
53
)
53
54
if skip_connection is not None :
54
55
feature_map = torch .cat ([feature_map , skip_connection ], dim = 1 )
@@ -59,7 +60,7 @@ def forward(
59
60
return feature_map
60
61
61
62
62
- class CenterBlock (nn .Sequential ):
63
+ class UnetCenterBlock (nn .Sequential ):
63
64
"""Center block of the Unet decoder. Applied to the last feature map of the encoder."""
64
65
65
66
def __init__ (self , in_channels : int , out_channels : int , use_batchnorm : bool = True ):
@@ -81,6 +82,12 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
81
82
82
83
83
84
class UnetDecoder (nn .Module ):
85
+ """The decoder part of the U-Net architecture.
86
+
87
+ Takes encoded features from different stages of the encoder and progressively upsamples them while
88
+ combining with skip connections. This helps preserve fine-grained details in the final segmentation.
89
+ """
90
+
84
91
def __init__ (
85
92
self ,
86
93
encoder_channels : Sequence [int ],
@@ -89,6 +96,7 @@ def __init__(
89
96
use_batchnorm : bool = True ,
90
97
attention_type : Optional [str ] = None ,
91
98
add_center_block : bool = False ,
99
+ interpolation_mode : str = "nearest" ,
92
100
):
93
101
super ().__init__ ()
94
102
@@ -111,7 +119,7 @@ def __init__(
111
119
out_channels = decoder_channels
112
120
113
121
if add_center_block :
114
- self .center = CenterBlock (
122
+ self .center = UnetCenterBlock (
115
123
head_channels , head_channels , use_batchnorm = use_batchnorm
116
124
)
117
125
else :
@@ -122,12 +130,13 @@ def __init__(
122
130
for block_in_channels , block_skip_channels , block_out_channels in zip (
123
131
in_channels , skip_channels , out_channels
124
132
):
125
- block = DecoderBlock (
133
+ block = UnetDecoderBlock (
126
134
block_in_channels ,
127
135
block_skip_channels ,
128
136
block_out_channels ,
129
137
use_batchnorm = use_batchnorm ,
130
138
attention_type = attention_type ,
139
+ interpolation_mode = interpolation_mode ,
131
140
)
132
141
self .blocks .append (block )
133
142
0 commit comments