import torch.nn as nn
import torch.nn.functional as F

+from typing import Optional, Sequence
from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
    def __init__(
        self,
-        in_channels,
-        skip_channels,
-        out_channels,
-        use_batchnorm=True,
-        attention_type=None,
+        in_channels: int,
+        skip_channels: int,
+        out_channels: int,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        interpolation_mode: str = "nearest",
    ):
        super().__init__()
+        self.interpolate_mode = interpolation_mode
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
@@ -34,19 +37,32 @@ def __init__(
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if skip is not None:
-            x = torch.cat([x, skip], dim=1)
-        x = self.attention1(x)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.attention2(x)
-        return x
+    def forward(
+        self,
+        feature_map: torch.Tensor,
+        target_height: int,
+        target_width: int,
+        skip_connection: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Upsample feature map to the given spatial shape, concatenate with skip connection,
+        apply attention block (if specified) and then apply two convolutions.
+        """
+        feature_map = F.interpolate(
+            feature_map, size=(target_height, target_width), mode=self.interpolate_mode
+        )
+        if skip_connection is not None:
+            feature_map = torch.cat([feature_map, skip_connection], dim=1)
+        feature_map = self.attention1(feature_map)
+        feature_map = self.conv1(feature_map)
+        feature_map = self.conv2(feature_map)
+        feature_map = self.attention2(feature_map)
+        return feature_map


class CenterBlock(nn.Sequential):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
+
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
        conv1 = md.Conv2dReLU(
            in_channels,
            out_channels,
@@ -67,12 +83,12 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
class UnetDecoder(nn.Module):
    def __init__(
        self,
-        encoder_channels,
-        decoder_channels,
-        n_blocks=5,
-        use_batchnorm=True,
-        attention_type=None,
-        center=False,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
+        n_blocks: int = 5,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        add_center_block: bool = False,
    ):
        super().__init__()

@@ -94,31 +110,44 @@ def __init__(
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

-        if center:
+        if add_center_block:
            self.center = CenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
-        blocks = [
-            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
-            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
-        ]
-        self.blocks = nn.ModuleList(blocks)
-
-    def forward(self, *features):
+        self.blocks = nn.ModuleList()
+        for block_in_channels, block_skip_channels, block_out_channels in zip(
+            in_channels, skip_channels, out_channels
+        ):
+            block = DecoderBlock(
+                block_in_channels,
+                block_skip_channels,
+                block_out_channels,
+                use_batchnorm=use_batchnorm,
+                attention_type=attention_type,
+            )
+            self.blocks.append(block)
+
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
+        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
+        spatial_shapes = [feature.shape[2:] for feature in features]
+        spatial_shapes = spatial_shapes[::-1]
+
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
-        skips = features[1:]
+        skip_connections = features[1:]

        x = self.center(head)
+
        for i, decoder_block in enumerate(self.blocks):
-            skip = skips[i] if i < len(skips) else None
-            x = decoder_block(x, skip)
+            # upsample to the next spatial shape
+            height, width = spatial_shapes[i + 1]
+            skip_connection = skip_connections[i] if i < len(skip_connections) else None
+            x = decoder_block(x, height, width, skip_connection=skip_connection)

        return x
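
For reference, a minimal sketch (not part of the diff) of how the reworked decoder could be exercised. The import path, channel counts, and spatial sizes below are illustrative assumptions, and it presumes segmentation_models_pytorch is installed so that the md building blocks resolve. Since each DecoderBlock now interpolates to the exact spatial shape of the matching encoder feature rather than a fixed scale_factor=2, skip concatenation also lines up for inputs whose sides are not divisible by 32, e.g. 250x250.

import torch

# Assumed import path for this module inside the package; adjust if it lives elsewhere.
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder

# Illustrative channel layout for a ResNet-style encoder and a standard Unet decoder.
encoder_channels = (3, 64, 64, 128, 256, 512)
decoder_channels = (256, 128, 64, 32, 16)

decoder = UnetDecoder(
    encoder_channels=encoder_channels,
    decoder_channels=decoder_channels,
    n_blocks=5,
    add_center_block=False,
)

# Fake encoder features for a 250x250 input: each stage roughly halves the spatial
# size (rounded as a strided conv typically would), so sides are not powers of two.
sizes = [(250, 250), (125, 125), (63, 63), (32, 32), (16, 16), (8, 8)]
features = [
    torch.randn(1, channels, height, width)
    for channels, (height, width) in zip(encoder_channels, sizes)
]

output = decoder(*features)
print(output.shape)  # expected: torch.Size([1, 16, 250, 250])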