Skip to content

Commit b83e000

Browse files
SiarheiFedartsouqubvel
authored andcommitted
Add ability to use Activated BatchNorm in decoder. (#81)
1 parent c116426 commit b83e000

File tree

5 files changed

+39
-17
lines changed

5 files changed

+39
-17
lines changed

segmentation_models_pytorch/common/blocks.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,30 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0,
88
super().__init__()
99

1010
layers = [
11-
nn.Conv2d(in_channels, out_channels, kernel_size,
12-
stride=stride, padding=padding, bias=not (use_batchnorm)),
13-
nn.ReLU(inplace=True),
11+
nn.Conv2d(
12+
in_channels,
13+
out_channels,
14+
kernel_size,
15+
stride=stride,
16+
padding=padding,
17+
bias=not (use_batchnorm)
18+
)
1419
]
1520

16-
if use_batchnorm:
17-
layers.insert(1, nn.BatchNorm2d(out_channels, **batchnorm_params))
18-
21+
if use_batchnorm == 'inplace':
22+
try:
23+
from inplace_abn import InPlaceABN
24+
except ImportError:
25+
raise RuntimeError("In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. To install see: https://github.com/mapillary/inplace_abn")
26+
27+
layers.append(InPlaceABN(out_channels, activation='leaky_relu', activation_param=0.0, **batchnorm_params))
28+
elif use_batchnorm:
29+
layers.append(nn.BatchNorm2d(out_channels, **batchnorm_params))
30+
layers.append(nn.ReLU(inplace=True))
31+
else:
32+
layers.append(nn.ReLU(inplace=True))
33+
34+
1935
self.block = nn.Sequential(*layers)
2036

2137
def forward(self, x):

segmentation_models_pytorch/linknet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ class Linknet(EncoderDecoder):
1414
extractor to build segmentation model.
1515
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
1616
decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
17-
is used.
17+
is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption.
18+
One of [True, False, 'inplace']
1819
classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
1920
activation: activation function used in ``.predict(x)`` method for inference.
2021
One of [``sigmoid``, ``softmax``, callable, None]
21-
2222
Returns:
2323
``torch.nn.Module``: **Linknet**
2424

segmentation_models_pytorch/pspnet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ class PSPNet(EncoderDecoder):
1414
to construct PSP module on it.
1515
psp_out_channels: number of filters in PSP block.
1616
psp_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
17-
is used.
17+
is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption.
18+
One of [True, False, 'inplace']
1819
psp_aux_output: if ``True`` add auxiliary classification output for encoder training
1920
psp_dropout: spatial dropout rate between 0 and 1.
2021
classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
2122
activation: activation function used in ``.predict(x)`` method for inference.
2223
One of [``sigmoid``, ``softmax``, callable, None]
23-
2424
Returns:
2525
``torch.nn.Module``: **PSPNet**
2626

segmentation_models_pytorch/unet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ class Unet(EncoderDecoder):
1212
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
1313
decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks
1414
decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
15-
is used.
15+
is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption.
16+
One of [True, False, 'inplace']
1617
classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
1718
activation: activation function used in ``.predict(x)`` method for inference.
1819
One of [``sigmoid``, ``softmax``, callable, None]
1920
center: if ``True`` add ``Conv2dReLU`` block on encoder head (useful for VGG models)
2021
attention_type: attention module used in decoder of the model
2122
One of [``None``, ``scse``]
22-
2323
Returns:
2424
``torch.nn.Module``: **Unet**
2525

tests/test_models.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import torch
66
import random
7+
import importlib
78

89
# mock detection module
910
sys.modules['torchvision._C'] = mock.Mock()
@@ -35,18 +36,17 @@ def _select_names(names, k=2):
3536
return names
3637

3738

38-
def _test_forward_backward(model_fn, encoder_name):
39-
40-
model = model_fn(encoder_name, encoder_weights=None)
39+
def _test_forward_backward(model_fn, encoder_name, **model_params):
40+
model = model_fn(encoder_name, encoder_weights=None, **model_params)
4141

4242
x = torch.ones((1, 3, 64, 64))
4343
y = model.forward(x)
4444
l = y.mean()
4545
l.backward()
4646

4747

48-
def _test_pretrained_model(model_fn, encoder_name, encoder_weights):
49-
model = model_fn(encoder_name, encoder_weights=encoder_weights)
48+
def _test_pretrained_model(model_fn, encoder_name, encoder_weights, **model_params):
49+
model = model_fn(encoder_name, encoder_weights=encoder_weights, **model_params)
5050

5151
x = torch.ones((1, 3, 64, 64))
5252
y = model.predict(x)
@@ -82,5 +82,11 @@ def test_pspnet(encoder_name):
8282
_test_pretrained_model(smp.PSPNet, encoder_name, get_pretrained_weights_name(encoder_name))
8383

8484

85+
@pytest.mark.skipif(importlib.util.find_spec('inplace_abn') is None, reason='')
86+
def test_inplace_abn():
87+
_test_forward_backward(smp.Unet, 'resnet18', decoder_use_batchnorm='inplace')
88+
_test_pretrained_model(smp.Unet, 'resnet18', get_pretrained_weights_name('resnet18'), decoder_use_batchnorm='inplace')
89+
90+
8591
if __name__ == '__main__':
8692
pytest.main([__file__])

0 commit comments

Comments
 (0)