Commit 4bd8d38

add unet++ (#279)
* add unet++
* update README.md
* update tests for unet++
* fixed test behaviour for unet++
1 parent 05012be commit 4bd8d38

File tree

6 files changed (+234, -6 lines)


README.md

Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 ### Models <a name="models"></a>
 
 #### Architectures <a name="architectires"></a>
-- [Unet](https://arxiv.org/abs/1505.04597)
+- [Unet](https://arxiv.org/abs/1505.04597) and [Unet++](https://arxiv.org/pdf/1807.10165.pdf)
 - [Linknet](https://arxiv.org/abs/1707.03718)
 - [FPN](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)
 - [PSPNet](https://arxiv.org/abs/1612.01105)

segmentation_models_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 from .unet import Unet
+from .unetplusplus import UnetPlusPlus
 from .linknet import Linknet
 from .fpn import FPN
 from .pspnet import PSPNet
segmentation_models_pytorch/unetplusplus/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .model import UnetPlusPlus
segmentation_models_pytorch/unetplusplus/decoder.py

Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..base import modules as md
+
+
+class DecoderBlock(nn.Module):
+    def __init__(
+            self,
+            in_channels,
+            skip_channels,
+            out_channels,
+            use_batchnorm=True,
+            attention_type=None,
+    ):
+        super().__init__()
+        self.conv1 = md.Conv2dReLU(
+            in_channels + skip_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
+        self.conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        self.attention2 = md.Attention(attention_type, in_channels=out_channels)
+
+    def forward(self, x, skip=None):
+        x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if skip is not None:
+            x = torch.cat([x, skip], dim=1)
+            x = self.attention1(x)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.attention2(x)
+        return x
+
+
+class CenterBlock(nn.Sequential):
+    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+        conv1 = md.Conv2dReLU(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        super().__init__(conv1, conv2)
+
+
+class UnetPlusPlusDecoder(nn.Module):
+    def __init__(
+            self,
+            encoder_channels,
+            decoder_channels,
+            n_blocks=5,
+            use_batchnorm=True,
+            attention_type=None,
+            center=False,
+    ):
+        super().__init__()
+
+        if n_blocks != len(decoder_channels):
+            raise ValueError(
+                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
+                    n_blocks, len(decoder_channels)
+                )
+            )
+
+        encoder_channels = encoder_channels[1:]  # remove first skip with same spatial resolution
+        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder
+        # computing blocks input and output channels
+        head_channels = encoder_channels[0]
+        self.in_channels = [head_channels] + list(decoder_channels[:-1])
+        self.skip_channels = list(encoder_channels[1:]) + [0]
+        self.out_channels = decoder_channels
+        if center:
+            self.center = CenterBlock(
+                head_channels, head_channels, use_batchnorm=use_batchnorm
+            )
+        else:
+            self.center = nn.Identity()
+
+        # combine decoder keyword arguments
+        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
+
+        blocks = {}
+        for layer_idx in range(len(self.in_channels) - 1):
+            for depth_idx in range(layer_idx + 1):
+                if depth_idx == 0:
+                    in_ch = self.in_channels[layer_idx]
+                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
+                    out_ch = self.out_channels[layer_idx]
+                else:
+                    out_ch = self.skip_channels[layer_idx]
+                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1 - depth_idx)
+                    in_ch = self.skip_channels[layer_idx - 1]
+                blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
+        blocks[f'x_{0}_{len(self.in_channels)-1}'] = \
+            DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs)
+        self.blocks = nn.ModuleDict(blocks)
+        self.depth = len(self.in_channels) - 1
+
+    def forward(self, *features):
+
+        features = features[1:]  # remove first skip with same spatial resolution
+        features = features[::-1]  # reverse channels to start from head of encoder
+        # start building dense connections
+        dense_x = {}
+        for layer_idx in range(len(self.in_channels) - 1):
+            for depth_idx in range(self.depth - layer_idx):
+                if layer_idx == 0:
+                    output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx + 1])
+                    dense_x[f'x_{depth_idx}_{depth_idx}'] = output
+                else:
+                    dense_l_i = depth_idx + layer_idx
+                    cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx + 1, dense_l_i + 1)]
+                    cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1)
+                    dense_x[f'x_{depth_idx}_{dense_l_i}'] = \
+                        self.blocks[f'x_{depth_idx}_{dense_l_i}'](dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features)
+        dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth-1}'])
+        return dense_x[f'x_{0}_{self.depth}']
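
The `x_{depth}_{layer}` keys in `blocks` index the nested grid of intermediate nodes from the Unet++ paper: depth-0 blocks form the outer decoder path, and each deeper block concatenates every intermediate output available at its resolution before convolving. A minimal shape-check sketch of the decoder in isolation; the import path and the channel tuple (mirroring what this library's resnet34 encoder reports) are assumptions, not part of this commit:

import torch
from segmentation_models_pytorch.unetplusplus.decoder import UnetPlusPlusDecoder

encoder_channels = (3, 64, 64, 128, 256, 512)  # assumed resnet34 out_channels
decoder = UnetPlusPlusDecoder(
    encoder_channels=encoder_channels,
    decoder_channels=(256, 128, 64, 32, 16),
    n_blocks=5,
)

# Fake encoder pyramid for a 64x64 input: feature i has resolution H / 2**i;
# the decoder itself drops the first (full-resolution) feature.
features = [torch.rand(1, ch, 64 // 2 ** i, 64 // 2 ** i)
            for i, ch in enumerate(encoder_channels)]

out = decoder(*features)
print(out.shape)  # expected: torch.Size([1, 16, 64, 64])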
segmentation_models_pytorch/unetplusplus/model.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+from typing import Optional, Union, List
+from .decoder import UnetPlusPlusDecoder
+from ..encoders import get_encoder
+from ..base import SegmentationModel
+from ..base import SegmentationHead, ClassificationHead
+
+
+class UnetPlusPlus(SegmentationModel):
+    """Unet++_ is a fully convolutional neural network for image semantic segmentation
+
+    Args:
+        encoder_name: name of classification model (without last dense layers) used as feature
+            extractor to build segmentation model.
+        encoder_depth (int): number of stages used in encoder; larger depth means more features are generated.
+            e.g. for depth=3 encoder will generate list of features with following spatial shapes
+            [(H, W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have
+            spatial resolution (H/(2^depth), W/(2^depth))
+        encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
+        decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks
+        decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
+            is used. If ``'inplace'``, ``InplaceABN`` will be used, which allows decreasing memory consumption.
+            One of [``True``, ``False``, ``'inplace'``]
+        decoder_attention_type: attention module used in decoder of the model.
+            One of [``None``, ``scse``]
+        in_channels: number of input channels for model, default is 3.
+        classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
+        activation: activation function to apply after final convolution;
+            One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None]
+        aux_params: if specified, model will have additional classification auxiliary output
+            built on top of encoder, supported params:
+                - classes (int): number of classes
+                - pooling (str): one of 'max', 'avg'. Default is 'avg'.
+                - dropout (float): dropout factor in [0, 1)
+                - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **Unet++**
+
+    .. _UnetPlusPlus:
+        https://arxiv.org/pdf/1807.10165.pdf
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: str = "imagenet",
+        decoder_use_batchnorm: bool = True,
+        decoder_channels: List[int] = (256, 128, 64, 32, 16),
+        decoder_attention_type: Optional[str] = None,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, callable]] = None,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = UnetPlusPlusDecoder(
+            encoder_channels=self.encoder.out_channels,
+            decoder_channels=decoder_channels,
+            n_blocks=encoder_depth,
+            use_batchnorm=decoder_use_batchnorm,
+            center=True if encoder_name.startswith("vgg") else False,
+            attention_type=decoder_attention_type,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=decoder_channels[-1],
+            out_channels=classes,
+            activation=activation,
+            kernel_size=3,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "u-{}".format(encoder_name)
+        self.initialize()
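With the `__init__.py` change above, the new class is importable as `smp.UnetPlusPlus` and shares its constructor surface with `smp.Unet`. A quick smoke-test sketch (random-init encoder so no ImageNet download is needed; the 64x64 input matches what the tests below use):

import torch
import segmentation_models_pytorch as smp

model = smp.UnetPlusPlus(
    encoder_name="resnet34",
    encoder_weights=None,  # random init, skips the pretrained download
    in_channels=3,
    classes=2,
)
model.eval()

with torch.no_grad():
    mask = model(torch.rand(1, 3, 64, 64))

print(mask.shape)  # torch.Size([1, 2, 64, 64])
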
tests/test_models.py

Lines changed: 5 additions & 5 deletions

@@ -30,7 +30,7 @@ def get_encoders():
 
 
 def get_sample(model_class):
-    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet]:
+    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus]:
         sample = torch.ones([1, 3, 64, 64])
     elif model_class == smp.PAN:
         sample = torch.ones([2, 3, 256, 256])
@@ -57,9 +57,9 @@ def _test_forward_backward(model, sample, test_shape=False):
 
 @pytest.mark.parametrize("encoder_name", ENCODERS)
 @pytest.mark.parametrize("encoder_depth", [3, 5])
-@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet])
+@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
 def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
-    if model_class is smp.Unet:
+    if model_class is smp.Unet or model_class is smp.UnetPlusPlus:
         kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
     model = model_class(
         encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
@@ -76,15 +76,15 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
 
 @pytest.mark.parametrize(
     "model_class",
-    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.DeepLabV3]
+    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.DeepLabV3]
 )
 def test_forward_backward(model_class):
     sample = get_sample(model_class)
     model = model_class(DEFAULT_ENCODER, encoder_weights=None)
     _test_forward_backward(model, sample)
 
 
-@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet])
+@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
 def test_aux_output(model_class):
     model = model_class(
         DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)

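Note how the depth-3 case in `test_forward` trims `decoder_channels` to match `encoder_depth`; `UnetPlusPlusDecoder.__init__` raises a `ValueError` when the two disagree. The same pairing applies outside the tests; a sketch:

import torch
import segmentation_models_pytorch as smp

# len(decoder_channels) must equal encoder_depth, as enforced in
# UnetPlusPlusDecoder.__init__.
model = smp.UnetPlusPlus(
    "resnet34",
    encoder_depth=3,
    encoder_weights=None,
    decoder_channels=(16, 16, 16),  # (16,)*5 trimmed to the last 3, as in the test
)
out = model(torch.ones([1, 3, 64, 64]))
print(out.shape)  # torch.Size([1, 1, 64, 64]); classes defaults to 1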