Commit 5b105a8

laol777 authored and qubvel committed

Add scSE attention module for Unet (#53)

* added scSE module
* ability to choose type of attention for unet decoder
* update docstring

1 parent f70502e
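For context: scSE is the concurrent spatial and channel squeeze & excitation block of Roy et al. (MICCAI 2018). It computes a channel gate cSE(x) via global average pooling and 1x1 convolutions, a spatial gate sSE(x) via a 1x1 convolution, and returns x * cSE(x) + x * sSE(x). The diff below adds the module to common/blocks.py and threads an attention_type switch through the Unet decoder.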

3 files changed (+43, -7)

segmentation_models_pytorch/common/blocks.py

Lines changed: 16 additions & 0 deletions
@@ -20,3 +20,19 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0,
 
     def forward(self, x):
         return self.block(x)
+
+
+class SCSEModule(nn.Module):
+    def __init__(self, ch, re=16):
+        super().__init__()
+        self.cSE = nn.Sequential(nn.AdaptiveAvgPool2d(1),
+                                 nn.Conv2d(ch, ch//re, 1),
+                                 nn.ReLU(inplace=True),
+                                 nn.Conv2d(ch//re, ch, 1),
+                                 nn.Sigmoid()
+                                 )
+        self.sSE = nn.Sequential(nn.Conv2d(ch, ch, 1),
+                                 nn.Sigmoid())
+
+    def forward(self, x):
+        return x * self.cSE(x) + x * self.sSE(x)
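As a sanity check, a minimal sketch of the new module in isolation (import path taken from the diff; tensor sizes are illustrative). Note that this sSE variant keeps all ch channels through its 1x1 convolution, whereas the original paper squeezes to a single spatial map:

```python
import torch
from segmentation_models_pytorch.common.blocks import SCSEModule

x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
scse = SCSEModule(ch=64, re=16)  # cSE bottleneck: 64 -> 64//16 = 4 -> 64 channels
y = scse(x)

# cSE yields a (2, 64, 1, 1) channel gate, sSE a (2, 64, 32, 32) spatial gate;
# each gate reweights a copy of x and the two are summed, so the shape is preserved.
assert y.shape == x.shape
```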

segmentation_models_pytorch/unet/decoder.py

Lines changed: 23 additions & 7 deletions
@@ -2,13 +2,20 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..common.blocks import Conv2dReLU
+from ..common.blocks import Conv2dReLU, SCSEModule
 from ..base.model import Model
 
 
 class DecoderBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None):
         super().__init__()
+        if attention_type is None:
+            self.attention1 = nn.Identity()
+            self.attention2 = nn.Identity()
+        elif attention_type == 'scse':
+            self.attention1 = SCSEModule(in_channels)
+            self.attention2 = SCSEModule(out_channels)
+
         self.block = nn.Sequential(
             Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
             Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
@@ -19,7 +26,10 @@ def forward(self, x):
         x = F.interpolate(x, scale_factor=2, mode='nearest')
         if skip is not None:
             x = torch.cat([x, skip], dim=1)
+            x = self.attention1(x)
+
         x = self.block(x)
+        x = self.attention2(x)
         return x
 
 
@@ -38,6 +48,7 @@ def __init__(
             final_channels=1,
             use_batchnorm=True,
             center=False,
+            attention_type=None
     ):
         super().__init__()
 
@@ -50,11 +61,16 @@ def __init__(
         in_channels = self.compute_channels(encoder_channels, decoder_channels)
         out_channels = decoder_channels
 
-        self.layer1 = DecoderBlock(in_channels[0], out_channels[0], use_batchnorm=use_batchnorm)
-        self.layer2 = DecoderBlock(in_channels[1], out_channels[1], use_batchnorm=use_batchnorm)
-        self.layer3 = DecoderBlock(in_channels[2], out_channels[2], use_batchnorm=use_batchnorm)
-        self.layer4 = DecoderBlock(in_channels[3], out_channels[3], use_batchnorm=use_batchnorm)
-        self.layer5 = DecoderBlock(in_channels[4], out_channels[4], use_batchnorm=use_batchnorm)
+        self.layer1 = DecoderBlock(in_channels[0], out_channels[0],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
+        self.layer2 = DecoderBlock(in_channels[1], out_channels[1],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
+        self.layer3 = DecoderBlock(in_channels[2], out_channels[2],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
+        self.layer4 = DecoderBlock(in_channels[3], out_channels[3],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
+        self.layer5 = DecoderBlock(in_channels[4], out_channels[4],
+                                   use_batchnorm=use_batchnorm, attention_type=attention_type)
         self.final_conv = nn.Conv2d(out_channels[4], final_channels, kernel_size=(1, 1))
 
         self.initialize()
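Note the asymmetry in the wiring: attention1 gates the upsampled features right after skip concatenation (in_channels wide), while attention2 gates the output of the two Conv2dReLU layers (out_channels wide). A small sketch under assumed channel widths:

```python
from segmentation_models_pytorch.unet.decoder import DecoderBlock

# Illustrative widths: 128 channels enter after skip concatenation,
# 64 leave after the two Conv2dReLU layers.
block = DecoderBlock(128, 64, use_batchnorm=True, attention_type='scse')

print(block.attention1)  # SCSEModule over the 128 concatenated channels
print(block.attention2)  # SCSEModule over the 64 output channels
```

With attention_type=None both attributes are nn.Identity, so the forward path is unchanged for existing models.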

segmentation_models_pytorch/unet/model.py

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,8 @@ class Unet(EncoderDecoder):
         activation: activation function used in ``.predict(x)`` method for inference.
             One of [``sigmoid``, ``softmax``, callable, None]
         center: if ``True`` add ``Conv2dReLU`` block on encoder head (useful for VGG models)
+        attention_type: attention module used in decoder of the model
+            One of [``None``, ``scse``]
 
     Returns:
         ``torch.nn.Module``: **Unet**
@@ -35,6 +37,7 @@ def __init__(
             classes=1,
             activation='sigmoid',
             center=False,  # useful for VGG models
+            attention_type=None
     ):
         encoder = get_encoder(
             encoder_name,
@@ -47,6 +50,7 @@ def __init__(
             final_channels=classes,
             use_batchnorm=decoder_use_batchnorm,
             center=center,
+            attention_type=attention_type
         )
 
         super().__init__(encoder, decoder, activation)
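With the parameter threaded through Unet -> UnetDecoder -> DecoderBlock, enabling the module is a one-argument change. A minimal usage sketch ('resnet34' and the input size are illustrative assumptions; ``.predict(x)`` applies the chosen activation, per the docstring above):

```python
import torch
import segmentation_models_pytorch as smp

# scSE attention in every decoder block
model = smp.Unet('resnet34', classes=1, activation='sigmoid',
                 attention_type='scse')

x = torch.randn(1, 3, 256, 256)  # H and W divisible by 32 for the 5-stage encoder
mask = model.predict(x)          # sigmoid-activated (1, 1, 256, 256) mask
```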
