Skip to content

Commit e8a1825

Browse files
committed
update PAN decoder support encoder depth
1 parent 9389d8e commit e8a1825

File tree

1 file changed

+29
-23
lines changed
  • segmentation_models_pytorch/decoders/pan

1 file changed

+29
-23
lines changed

segmentation_models_pytorch/decoders/pan/decoder.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -175,34 +175,40 @@ def forward(self, x, y):
175175

176176
class PANDecoder(nn.Module):
177177
def __init__(
178-
self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear"
178+
self,
179+
encoder_channels,
180+
encoder_depth,
181+
decoder_channels,
182+
upscale_mode: str = "bilinear",
179183
):
180184
super().__init__()
181185

186+
if encoder_depth < 3:
187+
raise ValueError(
188+
"Encoder depth for PAN decoder cannot be less than 3, got {}.".format(
189+
encoder_depth
190+
)
191+
)
192+
193+
encoder_channels = encoder_channels[2:][::-1]
194+
182195
self.fpa = FPABlock(
183-
in_channels=encoder_channels[-1], out_channels=decoder_channels
184-
)
185-
self.gau3 = GAUBlock(
186-
in_channels=encoder_channels[-2],
187-
out_channels=decoder_channels,
188-
upscale_mode=upscale_mode,
189-
)
190-
self.gau2 = GAUBlock(
191-
in_channels=encoder_channels[-3],
192-
out_channels=decoder_channels,
193-
upscale_mode=upscale_mode,
194-
)
195-
self.gau1 = GAUBlock(
196-
in_channels=encoder_channels[-4],
197-
out_channels=decoder_channels,
198-
upscale_mode=upscale_mode,
196+
in_channels=encoder_channels[0], out_channels=decoder_channels
199197
)
200198

199+
for i in range(1, len(encoder_channels)):
200+
self.add_module(f"gau{len(encoder_channels)-i}", GAUBlock(
201+
in_channels=encoder_channels[i],
202+
out_channels=decoder_channels,
203+
upscale_mode=upscale_mode,
204+
))
205+
201206
def forward(self, *features):
202-
bottleneck = features[-1]
203-
x5 = self.fpa(bottleneck) # 1/32
204-
x4 = self.gau3(features[-2], x5) # 1/16
205-
x3 = self.gau2(features[-3], x4) # 1/8
206-
x2 = self.gau1(features[-4], x3) # 1/4
207+
features = features[2:] # remove first and second skip
208+
features = features[::-1] # reverse channels to start from head of encoder
209+
210+
out = self.fpa(features[0])
207211

208-
return x2
212+
for i in range(1, len(features)):
213+
out = getattr(self, f"gau{len(features)-i}")(features[i], out)
214+
return out

0 commit comments

Comments
 (0)