Skip to content

Commit 2b39afb

Browse files
authored
Add optional final upsampling for FPN (#72)
1 parent 127f930 commit 2b39afb

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

segmentation_models_pytorch/fpn/decoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,12 @@ def __init__(
6565
encoder_channels,
6666
pyramid_channels=256,
6767
segmentation_channels=128,
68+
final_upsampling=4,
6869
final_channels=1,
6970
dropout=0.2,
7071
):
7172
super().__init__()
72-
73+
self.final_upsampling = final_upsampling
7374
self.conv1 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=(1, 1))
7475

7576
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
@@ -104,5 +105,6 @@ def forward(self, x):
104105
x = self.dropout(x)
105106
x = self.final_conv(x)
106107

107-
x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True)
108+
if self.final_upsampling is not None and self.final_upsampling > 1:
109+
x = F.interpolate(x, scale_factor=self.final_upsampling, mode='bilinear', align_corners=True)
108110
return x

segmentation_models_pytorch/fpn/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ class FPN(EncoderDecoder):
1515
dropout: spatial dropout rate in range (0, 1).
1616
activation: activation function used in ``.predict(x)`` method for inference.
1717
One of [``sigmoid``, ``softmax``, callable, None]
18+
final_upsampling: optional, final upsampling factor
19+
(default is 4 to preserve input -> output spatial shape identity)
1820
1921
Returns:
2022
``torch.nn.Module``: **FPN**
@@ -33,6 +35,7 @@ def __init__(
3335
classes=1,
3436
dropout=0.2,
3537
activation='sigmoid',
38+
final_upsampling=4,
3639
):
3740
encoder = get_encoder(
3841
encoder_name,
@@ -45,6 +48,7 @@ def __init__(
4548
segmentation_channels=decoder_segmentation_channels,
4649
final_channels=classes,
4750
dropout=dropout,
51+
final_upsampling=final_upsampling,
4852
)
4953

5054
super().__init__(encoder, decoder, activation)

0 commit comments

Comments
 (0)