Skip to content

Commit c116426

Browse files
SiarheiFedartsouqubvel
authored andcommitted
Add ability to use concatenation and addition in FPN (#82)
1 parent f249c81 commit c116426

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

segmentation_models_pytorch/fpn/decoder.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,14 @@ def __init__(
6868
final_upsampling=4,
6969
final_channels=1,
7070
dropout=0.2,
71+
merge_policy='add'
7172
):
7273
super().__init__()
74+
75+
if merge_policy not in ['add', 'cat']:
76+
raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(merge_policy))
77+
self.merge_policy = merge_policy
78+
7379
self.final_upsampling = final_upsampling
7480
self.conv1 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=(1, 1))
7581

@@ -83,6 +89,10 @@ def __init__(
8389
self.s2 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=0)
8490

8591
self.dropout = nn.Dropout2d(p=dropout, inplace=True)
92+
93+
if self.merge_policy == 'cat':
94+
segmentation_channels *= 4
95+
8696
self.final_conv = nn.Conv2d(segmentation_channels, final_channels, kernel_size=1, padding=0)
8797

8898
self.initialize()
@@ -100,7 +110,10 @@ def forward(self, x):
100110
s3 = self.s3(p3)
101111
s2 = self.s2(p2)
102112

103-
x = s5 + s4 + s3 + s2
113+
if self.merge_policy == 'add':
114+
x = s5 + s4 + s3 + s2
115+
elif self.merge_policy == 'cat':
116+
x = torch.cat([s5, s4, s3, s2], dim=1)
104117

105118
x = self.dropout(x)
106119
x = self.final_conv(x)

segmentation_models_pytorch/fpn/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class FPN(EncoderDecoder):
1717
One of [``sigmoid``, ``softmax``, callable, None]
1818
final_upsampling: optional, final upsampling factor
1919
(default is 4 to preserve input -> output spatial shape identity)
20-
20+
decoder_merge_policy: determines how to merge outputs inside FPN.
21+
One of [``add``, ``cat``]
2122
Returns:
2223
``torch.nn.Module``: **FPN**
2324
@@ -36,6 +37,7 @@ def __init__(
3637
dropout=0.2,
3738
activation='sigmoid',
3839
final_upsampling=4,
40+
decoder_merge_policy='add'
3941
):
4042
encoder = get_encoder(
4143
encoder_name,
@@ -49,6 +51,7 @@ def __init__(
4951
final_channels=classes,
5052
dropout=dropout,
5153
final_upsampling=final_upsampling,
54+
merge_policy=decoder_merge_policy
5255
)
5356

5457
super().__init__(encoder, decoder, activation)

tests/test_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def test_fpn(encoder_name):
6565
_test_forward_backward(smp.FPN, encoder_name)
6666
_test_pretrained_model(smp.FPN, encoder_name, get_pretrained_weights_name(encoder_name))
6767

68+
from functools import partial
69+
_test_forward_backward(partial(smp.FPN, decoder_merge_policy='cat'), encoder_name)
70+
_test_pretrained_model(partial(smp.FPN, decoder_merge_policy='cat'), encoder_name, get_pretrained_weights_name(encoder_name))
71+
6872

6973
@pytest.mark.parametrize('encoder_name', _select_names(ENCODERS, k=1))
7074
def test_linknet(encoder_name):

0 commit comments

Comments
 (0)