
Commit 92f3ae4

Merge branch 'upgrade-fpn'
2 parents (03be12a + 94636d7) · commit 92f3ae4

2 files changed: +37, -20 lines

segmentation_models_pytorch/fpn/decoder.py
37 additions, 18 deletions

@@ -2,10 +2,29 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..common.blocks import Conv2dReLU
 from ..base.model import Model
 
 
+class Conv3x3GNReLU(nn.Module):
+    def __init__(self, in_channels, out_channels, upsample=False):
+        super().__init__()
+        self.upsample = upsample
+        self.block = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, (3, 3),
+                      stride=1, padding=1, bias=False),
+            nn.GroupNorm(32, out_channels),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        x = self.block(x)
+        if self.upsample:
+            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+        return x
+
+
 class FPNBlock(nn.Module):
     def __init__(self, pyramid_channels, skip_channels):
         super().__init__()
@@ -22,12 +41,18 @@ def forward(self, x):
 
 
 class SegmentationBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    def __init__(self, in_channels, out_channels, n_upsamples=0):
         super().__init__()
-        self.block = nn.Sequential(
-            Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
-            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
-        )
+
+        blocks = [
+            Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))
+        ]
+
+        if n_upsamples > 1:
+            for _ in range(1, n_upsamples):
+                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
+
+        self.block = nn.Sequential(*blocks)
 
     def forward(self, x):
         return self.block(x)
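
The new SegmentationBlock chains one Conv3x3GNReLU per requested upsampling step, and each step doubles the spatial size. A minimal sanity-check sketch (assuming the package at this commit is installed, so the classes above are importable from segmentation_models_pytorch.fpn.decoder; the tensor sizes are illustrative only):

import torch
# Assumes the post-merge decoder module from this diff is importable.
from segmentation_models_pytorch.fpn.decoder import SegmentationBlock

# A P5-like feature map: batch 1, 256 pyramid channels, 16x16 spatial size.
p5 = torch.randn(1, 256, 16, 16)

# n_upsamples=3 stacks three Conv3x3GNReLU blocks, each upsampling by 2x,
# so the output is 8x larger spatially (16 -> 128) with 128 channels.
block = SegmentationBlock(in_channels=256, out_channels=128, n_upsamples=3)
out = block(p5)
print(out.shape)  # torch.Size([1, 128, 128, 128])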
@@ -42,7 +67,6 @@ def __init__(
             segmentation_channels=128,
             final_channels=1,
             dropout=0.2,
-            use_batchnorm=True,
     ):
         super().__init__()
 
@@ -52,13 +76,13 @@ def __init__(
         self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
         self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
 
-        self.s5 = SegmentationBlock(pyramid_channels, segmentation_channels, use_batchnorm=use_batchnorm)
-        self.s4 = SegmentationBlock(pyramid_channels, segmentation_channels, use_batchnorm=use_batchnorm)
-        self.s3 = SegmentationBlock(pyramid_channels, segmentation_channels, use_batchnorm=use_batchnorm)
-        self.s2 = SegmentationBlock(pyramid_channels, segmentation_channels, use_batchnorm=use_batchnorm)
+        self.s5 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=3)
+        self.s4 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=2)
+        self.s3 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=1)
+        self.s2 = SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=0)
 
         self.dropout = nn.Dropout2d(p=dropout, inplace=True)
-        self.final_conv = nn.Conv2d(4 * segmentation_channels, final_channels, kernel_size=3, padding=1)
+        self.final_conv = nn.Conv2d(segmentation_channels, final_channels, kernel_size=1, padding=0)
 
         self.initialize()
 
@@ -75,12 +99,7 @@ def forward(self, x):
         s3 = self.s3(p3)
         s2 = self.s2(p2)
 
-        x = torch.cat([
-            F.interpolate(s5, scale_factor=8, mode='bilinear', align_corners=True),
-            F.interpolate(s4, scale_factor=4, mode='bilinear', align_corners=True),
-            F.interpolate(s3, scale_factor=2, mode='bilinear', align_corners=True),
-            s2,
-        ], dim=1)
+        x = s5 + s4 + s3 + s2
 
         x = self.dropout(x)
         x = self.final_conv(x)
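
Since the upsampling now happens inside the segmentation heads (n_upsamples 3/2/1/0), all four branches reach the same 1/4-resolution grid with segmentation_channels channels, so the decoder can merge them with an element-wise sum and a 1x1 conv instead of concatenating four maps and running a 3x3 conv. A rough standalone sketch of that change in the merge step; the shapes are illustrative (a 256x256 input with the usual stride-4..32 pyramid), not taken from the repo:

import torch
import torch.nn as nn

# Illustrative shapes only: each head output sits at 1/4 resolution (64x64)
# with 128 segmentation channels after its own upsampling chain.
s5 = torch.randn(1, 128, 64, 64)  # upsampled 8x by its head (n_upsamples=3)
s4 = torch.randn(1, 128, 64, 64)  # upsampled 4x (n_upsamples=2)
s3 = torch.randn(1, 128, 64, 64)  # upsampled 2x (n_upsamples=1)
s2 = torch.randn(1, 128, 64, 64)  # already at 1/4 resolution (n_upsamples=0)

# Old merge: concat along channels -> 4 * 128 input channels to a 3x3 conv.
old_final = nn.Conv2d(4 * 128, 1, kernel_size=3, padding=1)
x_old = old_final(torch.cat([s5, s4, s3, s2], dim=1))

# New merge: element-wise sum keeps 128 channels -> 1x1 conv.
new_final = nn.Conv2d(128, 1, kernel_size=1, padding=0)
x_new = new_final(s5 + s4 + s3 + s2)

print(x_old.shape, x_new.shape)  # both torch.Size([1, 1, 64, 64])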

segmentation_models_pytorch/fpn/model.py
0 additions, 2 deletions

@@ -11,7 +11,6 @@ def __init__(
             encoder_weights='imagenet',
             decoder_pyramid_channels=256,
             decoder_segmenation_channels=128,
-            decoder_use_batchnorm=True,
             classes=1,
             dropout=0.2,
             activation='sigmoid',
@@ -28,7 +27,6 @@ def __init__(
             segmentation_channels=decoder_segmenation_channels,
             final_channels=classes,
             dropout=dropout,
-            use_batchnorm=decoder_use_batchnorm,
         )
 
         super().__init__(encoder, decoder, activation)
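
With this change, decoder_use_batchnorm is no longer accepted by the FPN constructor; the decoder's normalization is fixed to GroupNorm via Conv3x3GNReLU. A minimal usage sketch under assumptions not stated in this diff (that the package exposes FPN at the top level and that 'resnet34' is an available encoder_name; the other keyword names are taken from model.py, including its decoder_segmenation_channels spelling):

from segmentation_models_pytorch import FPN

# decoder_use_batchnorm is gone; the decoder now always uses GroupNorm.
model = FPN(
    encoder_name='resnet34',           # example encoder, not mandated by this commit
    encoder_weights='imagenet',
    decoder_pyramid_channels=256,
    decoder_segmenation_channels=128,  # parameter name as spelled in model.py
    classes=1,
    dropout=0.2,
    activation='sigmoid',
)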
