Skip to content

Commit 8d1e565

Browse files
committed
update UPerNet decoder
Resize all FPN output features to 1/4 of the original resolution.
1 parent b636261 commit 8d1e565

File tree

1 file changed

+6
-2
lines changed
  • segmentation_models_pytorch/decoders/upernet

1 file changed

+6
-2
lines changed

segmentation_models_pytorch/decoders/upernet/decoder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def __init__(
113113
)
114114

115115
def forward(self, *features):
116-
# Resize all FPN features to the size of the largest feature
117-
target_size = features[0].shape[2:]
116+
output_size = features[0].shape[2:]
117+
target_size = [size // 4 for size in output_size]
118118

119119
features = features[1:] # remove first skip with same spatial resolution
120120
features = features[::-1] # reverse channels to start from head of encoder
@@ -126,6 +126,7 @@ def forward(self, *features):
126126
fpn_feature = stage(fpn_features[-1], feature)
127127
fpn_features.append(fpn_feature)
128128

129+
# Resize all FPN features to 1/4 of the original resolution.
129130
resized_fpn_features = []
130131
for feature in fpn_features:
131132
resized_feature = F.interpolate(
@@ -134,5 +135,8 @@ def forward(self, *features):
134135
resized_fpn_features.append(resized_feature)
135136

136137
output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1))
138+
output = F.interpolate(
139+
output, size=output_size, mode="bilinear", align_corners=False
140+
)
137141

138142
return output

0 commit comments

Comments
 (0)