@@ -175,34 +175,40 @@ def forward(self, x, y):
175175
176176class PANDecoder (nn .Module ):
177177 def __init__ (
178- self , encoder_channels , decoder_channels , upscale_mode : str = "bilinear"
178+ self ,
179+ encoder_channels ,
180+ encoder_depth ,
181+ decoder_channels ,
182+ upscale_mode : str = "bilinear" ,
179183 ):
180184 super ().__init__ ()
181185
186+ if encoder_depth < 3 :
187+ raise ValueError (
188+ "Encoder depth for PAN decoder cannot be less than 3, got {}." .format (
189+ encoder_depth
190+ )
191+ )
192+
193+ encoder_channels = encoder_channels [2 :][::- 1 ]
194+
182195 self .fpa = FPABlock (
183- in_channels = encoder_channels [- 1 ], out_channels = decoder_channels
184- )
185- self .gau3 = GAUBlock (
186- in_channels = encoder_channels [- 2 ],
187- out_channels = decoder_channels ,
188- upscale_mode = upscale_mode ,
189- )
190- self .gau2 = GAUBlock (
191- in_channels = encoder_channels [- 3 ],
192- out_channels = decoder_channels ,
193- upscale_mode = upscale_mode ,
194- )
195- self .gau1 = GAUBlock (
196- in_channels = encoder_channels [- 4 ],
197- out_channels = decoder_channels ,
198- upscale_mode = upscale_mode ,
196+ in_channels = encoder_channels [0 ], out_channels = decoder_channels
199197 )
200198
199+ for i in range (1 , len (encoder_channels )):
200+ self .add_module (f"gau{ len (encoder_channels )- i } " , GAUBlock (
201+ in_channels = encoder_channels [i ],
202+ out_channels = decoder_channels ,
203+ upscale_mode = upscale_mode ,
204+ ))
205+
201206 def forward (self , * features ):
202- bottleneck = features [- 1 ]
203- x5 = self .fpa (bottleneck ) # 1/32
204- x4 = self .gau3 (features [- 2 ], x5 ) # 1/16
205- x3 = self .gau2 (features [- 3 ], x4 ) # 1/8
206- x2 = self .gau1 (features [- 4 ], x3 ) # 1/4
207+ features = features [2 :] # remove first and second skip
208+ features = features [::- 1 ] # reverse channels to start from head of encoder
209+
210+ out = self .fpa (features [0 ])
207211
208- return x2
212+ for i in range (1 , len (features )):
213+ out = getattr (self , f"gau{ len (features )- i } " )(features [i ], out )
214+ return out
0 commit comments