Skip to content

Commit 3179751

Browse files
committed
fix issue #377
1 parent edadc0d commit 3179751

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class DeepLabV3PlusDecoder(nn.Module):
7171
def __init__(
7272
self,
7373
encoder_channels: Sequence[int, ...],
74+
encoder_depth: Literal[3, 4, 5],
7475
out_channels: int,
7576
atrous_rates: Iterable[int],
7677
output_stride: Literal[8, 16],
@@ -104,7 +105,14 @@ def __init__(
104105
scale_factor = 2 if output_stride == 8 else 4
105106
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
106107

107-
highres_in_channels = encoder_channels[-4]
108+
if encoder_depth == 3 and output_stride == 8:
109+
self.highres_input_index = -2
110+
elif encoder_depth == 3 or encoder_depth == 4:
111+
self.highres_input_index = -3
112+
else:
113+
self.highres_input_index = -4
114+
115+
highres_in_channels = encoder_channels[self.highres_input_index]
108116
highres_out_channels = 48 # proposed by authors of paper
109117
self.block1 = nn.Sequential(
110118
nn.Conv2d(
@@ -128,7 +136,7 @@ def __init__(
128136
def forward(self, *features):
129137
aspp_features = self.aspp(features[-1])
130138
aspp_features = self.up(aspp_features)
131-
high_res_features = self.block1(features[-4])
139+
high_res_features = self.block1(features[self.highres_input_index])
132140
concat_features = torch.cat([aspp_features, high_res_features], dim=1)
133141
fused_features = self.block2(concat_features)
134142
return fused_features

segmentation_models_pytorch/decoders/deeplabv3/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class DeepLabV3Plus(SegmentationModel):
150150
def __init__(
151151
self,
152152
encoder_name: str = "resnet34",
153-
encoder_depth: int = 5,
153+
encoder_depth: Literal[3, 4, 5] = 5,
154154
encoder_weights: Optional[str] = "imagenet",
155155
encoder_output_stride: Literal[8, 16] = 16,
156156
decoder_channels: int = 256,
@@ -177,6 +177,7 @@ def __init__(
177177

178178
self.decoder = DeepLabV3PlusDecoder(
179179
encoder_channels=self.encoder.out_channels,
180+
encoder_depth=encoder_depth,
180181
out_channels=decoder_channels,
181182
atrous_rates=decoder_atrous_rates,
182183
output_stride=encoder_output_stride,

0 commit comments

Comments
 (0)