Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
The main features of this library are:

- High-level API (just two lines to create a neural network)
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
- 10 models architectures for binary and multi class segmentation (including legendary Unet)
- 124 available encoders (and 500+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
- All encoders have pre-trained weights for faster and better convergence
- Popular metrics and losses for training routines
Expand Down Expand Up @@ -94,6 +94,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
- PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)]
- DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
- DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
- UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]

#### Encoders <a name="encoders"></a>

Expand Down
7 changes: 7 additions & 0 deletions docs/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,10 @@ MAnet
PAN
~~~
.. autoclass:: segmentation_models_pytorch.PAN


.. _upernet:

UPerNet
~~~
.. autoclass:: segmentation_models_pytorch.UPerNet
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .decoders.pspnet import PSPNet
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
from .decoders.pan import PAN
from .decoders.upernet import UPerNet
from .base.hub_mixin import from_pretrained

from .__version__ import __version__
Expand Down Expand Up @@ -48,6 +49,7 @@ def create_model(
DeepLabV3,
DeepLabV3Plus,
PAN,
UPerNet,
]
archs_dict = {a.__name__.lower(): a for a in archs}
try:
Expand Down Expand Up @@ -82,6 +84,7 @@ def create_model(
"DeepLabV3",
"DeepLabV3Plus",
"PAN",
"UPerNet",
"from_pretrained",
"create_model",
"__version__",
Expand Down
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/upernet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .model import UPerNet

__all__ = ["UPerNet"]
142 changes: 142 additions & 0 deletions segmentation_models_pytorch/decoders/upernet/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class PSPModule(nn.Module):
def __init__(
self,
in_channels,
out_channels,
sizes=(1, 2, 3, 6),
use_batchnorm=True,
):
super().__init__()
self.blocks = nn.ModuleList(
[
nn.Sequential(
nn.AdaptiveAvgPool2d(size),
md.Conv2dReLU(
in_channels,
in_channels // len(sizes),
kernel_size=1,
use_batchnorm=use_batchnorm,
),
)
for size in sizes
]
)
self.out_conv = md.Conv2dReLU(
in_channels=in_channels * 2,
out_channels=out_channels,
kernel_size=1,
use_batchnorm=True,
)

def forward(self, x):
_, _, height, weight = x.shape
out = [x] + [
F.interpolate(
block(x), size=(height, weight), mode="bilinear", align_corners=False
)
for block in self.blocks
]
out = self.out_conv(torch.cat(out, dim=1))
return out


class FPNBlock(nn.Module):
def __init__(self, skip_channels, pyramid_channels, use_bathcnorm=True):
super().__init__()
self.skip_conv = (
md.Conv2dReLU(
skip_channels,
pyramid_channels,
kernel_size=1,
use_batchnorm=use_bathcnorm,
)
if skip_channels != 0
else nn.Identity()
)

def forward(self, x, skip):
_, channels, height, weight = skip.shape
x = F.interpolate(
x, size=(height, weight), mode="bilinear", align_corners=False
)
if channels != 0:
skip = self.skip_conv(skip)
x = x + skip
return x


class UPerNetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
encoder_depth=5,
pyramid_channels=256,
segmentation_channels=64,
):
super().__init__()
self.out_channels = segmentation_channels
if encoder_depth < 3:
raise ValueError(
"Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format(
encoder_depth
)
)

encoder_channels = encoder_channels[::-1]

# PSP Module
self.psp = PSPModule(
in_channels=encoder_channels[0],
out_channels=pyramid_channels,
sizes=(1, 2, 3, 6),
use_batchnorm=True,
)

# FPN Module
self.fpn_stages = nn.ModuleList(
[FPNBlock(ch, pyramid_channels) for ch in encoder_channels[1:]]
)

self.fpn_bottleneck = md.Conv2dReLU(
in_channels=(len(encoder_channels) - 1) * pyramid_channels,
out_channels=segmentation_channels,
kernel_size=3,
padding=1,
use_batchnorm=True,
)

def forward(self, *features):
output_size = features[0].shape[2:]
target_size = [size // 4 for size in output_size]

features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder

psp_out = self.psp(features[0])

fpn_features = [psp_out]
for feature, stage in zip(features[1:], self.fpn_stages):
fpn_feature = stage(fpn_features[-1], feature)
fpn_features.append(fpn_feature)

# Resize all FPN features to 1/4 of the original resolution.
resized_fpn_features = []
for feature in fpn_features:
resized_feature = F.interpolate(
feature, size=target_size, mode="bilinear", align_corners=False
)
resized_fpn_features.append(resized_feature)

output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1))
output = F.interpolate(
output, size=output_size, mode="bilinear", align_corners=False
)

return output
91 changes: 91 additions & 0 deletions segmentation_models_pytorch/decoders/upernet/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
)
from .decoder import UPerNetDecoder


class UPerNet(SegmentationModel):
"""UPerNet is a unified perceptual parsing network for image segmentation.

Args:
encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
to extract features of different spatial resolution
encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features
two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
Default is 5
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256
decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
on top of encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
- pooling (str): One of "max", "avg". Default is "avg"
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)

Returns:
``torch.nn.Module``: **UPerNet**

.. _UPerNet:
https://arxiv.org/abs/1807.10221

"""

def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_pyramid_channels: int = 256,
decoder_segmentation_channels: int = 64,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
):
super().__init__()

self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)

self.decoder = UPerNetDecoder(
encoder_channels=self.encoder.out_channels,
encoder_depth=encoder_depth,
pyramid_channels=decoder_pyramid_channels,
segmentation_channels=decoder_segmentation_channels,
)

self.segmentation_head = SegmentationHead(
in_channels=self.decoder.out_channels,
out_channels=classes,
activation=activation,
kernel_size=3,
)

if aux_params is not None:
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
)
else:
self.classification_head = None

self.name = "upernet-{}".format(encoder_name)
self.initialize()
4 changes: 3 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get_sample(model_class):
smp.PSPNet,
smp.UnetPlusPlus,
smp.MAnet,
smp.UPerNet,
]:
sample = torch.ones([1, 3, 64, 64])
elif model_class == smp.PAN:
Expand Down Expand Up @@ -57,7 +58,8 @@ def _test_forward_backward(model, sample, test_shape=False):
@pytest.mark.parametrize("encoder_name", ENCODERS)
@pytest.mark.parametrize("encoder_depth", [3, 5])
@pytest.mark.parametrize(
"model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus]
"model_class",
[smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet],
)
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
if (
Expand Down