@@ -71,6 +71,7 @@ class DeepLabV3PlusDecoder(nn.Module):
7171 def __init__ (
7272 self ,
7373 encoder_channels : Sequence [int , ...],
74+ encoder_depth : Literal [3 , 4 , 5 ],
7475 out_channels : int ,
7576 atrous_rates : Iterable [int ],
7677 output_stride : Literal [8 , 16 ],
@@ -104,7 +105,14 @@ def __init__(
104105 scale_factor = 2 if output_stride == 8 else 4
105106 self .up = nn .UpsamplingBilinear2d (scale_factor = scale_factor )
106107
107- highres_in_channels = encoder_channels [- 4 ]
108+ if encoder_depth == 3 and output_stride == 8 :
109+ self .highres_input_index = - 2
110+ elif encoder_depth == 3 or encoder_depth == 4 :
111+ self .highres_input_index = - 3
112+ else :
113+ self .highres_input_index = - 4
114+
115+ highres_in_channels = encoder_channels [self .highres_input_index ]
108116 highres_out_channels = 48 # proposed by authors of paper
109117 self .block1 = nn .Sequential (
110118 nn .Conv2d (
@@ -128,7 +136,7 @@ def __init__(
128136 def forward (self , * features ):
129137 aspp_features = self .aspp (features [- 1 ])
130138 aspp_features = self .up (aspp_features )
131- high_res_features = self .block1 (features [- 4 ])
139+ high_res_features = self .block1 (features [self . highres_input_index ])
132140 concat_features = torch .cat ([aspp_features , high_res_features ], dim = 1 )
133141 fused_features = self .block2 (concat_features )
134142 return fused_features
0 commit comments