Skip to content

Commit aa1e005

Browse files
committed
update decoder
1 parent 8d15c28 commit aa1e005

File tree

1 file changed

+27
-14
lines changed
  • segmentation_models_pytorch/decoders/pan

1 file changed

+27
-14
lines changed

segmentation_models_pytorch/decoders/pan/decoder.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,28 +195,41 @@ def __init__(
195195
)
196196
)
197197

198-
encoder_channels = encoder_channels[2:][::-1]
198+
encoder_channels = encoder_channels[2:]
199199

200200
self.fpa = FPABlock(
201-
in_channels=encoder_channels[0], out_channels=decoder_channels
201+
in_channels=encoder_channels[-1], out_channels=decoder_channels
202202
)
203203

204-
for i in range(1, len(encoder_channels)):
205-
self.add_module(
206-
f"gau{len(encoder_channels)-i}",
207-
GAUBlock(
208-
in_channels=encoder_channels[i],
209-
out_channels=decoder_channels,
210-
upscale_mode=upscale_mode,
211-
),
204+
if encoder_depth == 5:
205+
self.gau3 = GAUBlock(
206+
in_channels=encoder_channels[2],
207+
out_channels=decoder_channels,
208+
upscale_mode=upscale_mode,
209+
)
210+
if encoder_depth >= 4:
211+
self.gau2 = GAUBlock(
212+
in_channels=encoder_channels[1],
213+
out_channels=decoder_channels,
214+
upscale_mode=upscale_mode,
215+
)
216+
if encoder_depth >= 3:
217+
self.gau1 = GAUBlock(
218+
in_channels=encoder_channels[0],
219+
out_channels=decoder_channels,
220+
upscale_mode=upscale_mode,
212221
)
213222

214223
def forward(self, *features):
215224
features = features[2:] # remove first and second skip
216-
features = features[::-1] # reverse channels to start from head of encoder
217225

218-
out = self.fpa(features[0])
226+
out = self.fpa(features[-1]) # 1/16 or 1/32
227+
228+
if hasattr(self, "gau3"):
229+
out = self.gau3(features[2], out) # 1/16
230+
if hasattr(self, "gau2"):
231+
out = self.gau2(features[1], out) # 1/8
232+
if hasattr(self, "gau1"):
233+
out = self.gau1(features[0], out) # 1/4
219234

220-
for i in range(1, len(features)):
221-
out = getattr(self, f"gau{len(features)-i}")(features[i], out)
222235
return out

0 commit comments

Comments
 (0)