@@ -195,28 +195,41 @@ def __init__(
195195 )
196196 )
197197
198- encoder_channels = encoder_channels [2 :][:: - 1 ]
198+ encoder_channels = encoder_channels [2 :]
199199
200200 self .fpa = FPABlock (
201- in_channels = encoder_channels [0 ], out_channels = decoder_channels
201+ in_channels = encoder_channels [- 1 ], out_channels = decoder_channels
202202 )
203203
204- for i in range (1 , len (encoder_channels )):
205- self .add_module (
206- f"gau{ len (encoder_channels )- i } " ,
207- GAUBlock (
208- in_channels = encoder_channels [i ],
209- out_channels = decoder_channels ,
210- upscale_mode = upscale_mode ,
211- ),
204+ if encoder_depth == 5 :
205+ self .gau3 = GAUBlock (
206+ in_channels = encoder_channels [2 ],
207+ out_channels = decoder_channels ,
208+ upscale_mode = upscale_mode ,
209+ )
210+ if encoder_depth >= 4 :
211+ self .gau2 = GAUBlock (
212+ in_channels = encoder_channels [1 ],
213+ out_channels = decoder_channels ,
214+ upscale_mode = upscale_mode ,
215+ )
216+ if encoder_depth >= 3 :
217+ self .gau1 = GAUBlock (
218+ in_channels = encoder_channels [0 ],
219+ out_channels = decoder_channels ,
220+ upscale_mode = upscale_mode ,
212221 )
213222
214223 def forward (self , * features ):
215224 features = features [2 :] # remove first and second skip
216- features = features [::- 1 ] # reverse channels to start from head of encoder
217225
218- out = self .fpa (features [0 ])
226+ out = self .fpa (features [- 1 ]) # 1/16 or 1/32
227+
228+ if hasattr (self , "gau3" ):
229+ out = self .gau3 (features [2 ], out ) # 1/16
230+ if hasattr (self , "gau2" ):
231+ out = self .gau2 (features [1 ], out ) # 1/8
232+ if hasattr (self , "gau1" ):
233+ out = self .gau1 (features [0 ], out ) # 1/4
219234
220- for i in range (1 , len (features )):
221- out = getattr (self , f"gau{ len (features )- i } " )(features [i ], out )
222235 return out
0 commit comments