3 changes: 2 additions & 1 deletion README.md
@@ -19,7 +19,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
The main features of this library are:

- High-level API (just two lines to create a neural network)
- 10 models architectures for binary and multi class segmentation (including legendary Unet)
- 11 models architectures for binary and multi class segmentation (including legendary Unet)
- 124 available encoders (and 500+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
- All encoders have pre-trained weights for faster and better convergence
- Popular metrics and losses for training routines
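With Segformer added, the "two lines" claim above still holds for the new architecture; a minimal sketch using the class introduced in this PR (encoder choice is illustrative):

```python
import segmentation_models_pytorch as smp

# One import, one constructor call -- the "two lines" from the feature list
model = smp.Segformer(encoder_name="resnet34", encoder_weights="imagenet", classes=2)
```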
@@ -95,6 +95,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
- DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
- DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
- UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]
- Segformer [[paper](https://arxiv.org/abs/2105.15203)] [[docs](https://smp.readthedocs.io/en/latest/models.html#segformer)]

#### Encoders <a name="encoders"></a>

8 changes: 8 additions & 0 deletions docs/models.rst
@@ -73,3 +73,11 @@ PAN
UPerNet
~~~~~~~
.. autoclass:: segmentation_models_pytorch.UPerNet


.. _segformer:

Segformer
~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.Segformer

3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
@@ -15,6 +15,7 @@
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
from .decoders.pan import PAN
from .decoders.upernet import UPerNet
from .decoders.segformer import Segformer
from .base.hub_mixin import from_pretrained

from .__version__ import __version__
@@ -50,6 +51,7 @@ def create_model(
        DeepLabV3Plus,
        PAN,
        UPerNet,
        Segformer,
    ]
    archs_dict = {a.__name__.lower(): a for a in archs}
    try:
@@ -85,6 +87,7 @@ def create_model(
"DeepLabV3Plus",
"PAN",
"UPerNet",
"Segformer",
"from_pretrained",
"create_model",
"__version__",
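Because `Segformer` is now in the `archs` list, it is also reachable through the string-based factory; a minimal sketch (argument names per the existing `create_model` signature):

```python
import segmentation_models_pytorch as smp

# archs_dict lowercases class names, so "segformer" resolves to the new class
model = smp.create_model(
    "segformer",
    encoder_name="resnet34",
    in_channels=3,
    classes=2,
)
```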
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/segformer/__init__.py
@@ -0,0 +1,3 @@
from .model import Segformer

__all__ = ["Segformer"]
72 changes: 72 additions & 0 deletions segmentation_models_pytorch/decoders/segformer/decoder.py
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class MLP(nn.Module):
    def __init__(self, skip_channels, segmentation_channels):
        super().__init__()

        self.linear = nn.Linear(skip_channels, segmentation_channels)

    def forward(self, x: torch.Tensor):
        batch, _, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C) so the Linear acts per spatial position
        x = x.flatten(2).transpose(1, 2)
        x = self.linear(x)
        # back to (B, segmentation_channels, H, W)
        x = x.transpose(1, 2).reshape(batch, -1, height, width).contiguous()
        return x


class SegformerDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        encoder_depth=5,
        segmentation_channels=256,
    ):
        super().__init__()

        if encoder_depth < 3:
            raise ValueError(
                "Encoder depth for Segformer decoder cannot be less than 3, got {}.".format(
                    encoder_depth
                )
            )

        # Some encoders report a dummy zero-channel feature at index 1; drop it
        if encoder_channels[1] == 0:
            encoder_channels = tuple(
                channel for index, channel in enumerate(encoder_channels) if index != 1
            )
        encoder_channels = encoder_channels[::-1]

        self.mlp_stage = nn.ModuleList(
            [MLP(channel, segmentation_channels) for channel in encoder_channels[:-1]]
        )

        self.fuse_stage = md.Conv2dReLU(
            in_channels=(len(encoder_channels) - 1) * segmentation_channels,
            out_channels=segmentation_channels,
            kernel_size=1,
            use_batchnorm=True,
        )

    def forward(self, *features):
        # Resize all features to 1/4 of the input resolution
        # (features[0] is the identity feature at full input resolution)
        target_size = [dim // 4 for dim in features[0].shape[2:]]

        features = features[2:] if features[1].size(1) == 0 else features[1:]
        features = features[::-1]  # reverse channels to start from head of encoder

        resized_features = []
        for feature, stage in zip(features, self.mlp_stage):
            feature = stage(feature)
            resized_feature = F.interpolate(
                feature, size=target_size, mode="bilinear", align_corners=False
            )
            resized_features.append(resized_feature)

        output = self.fuse_stage(torch.cat(resized_features, dim=1))

        return output
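To make the resizing and fusion above concrete, here is a hedged shape walkthrough with a dummy feature pyramid. The channel counts and the import path are assumptions based on this diff, not taken from a specific encoder:

```python
import torch

from segmentation_models_pytorch.decoders.segformer.decoder import SegformerDecoder

# Dummy depth-5 pyramid for a 64x64 input: identity feature plus strides 2..32
channels = (3, 64, 64, 128, 256, 512)
features = [
    torch.randn(1, c, 64 // (2**i), 64 // (2**i)) for i, c in enumerate(channels)
]

decoder = SegformerDecoder(encoder_channels=channels, segmentation_channels=256)
out = decoder(*features)
print(out.shape)  # torch.Size([1, 256, 16, 16]) -- 1/4 of the input resolution
```

Note how the fuse stage expects `(len(encoder_channels) - 1) * segmentation_channels` input channels: five MLP outputs of 256 channels each are concatenated into 1280 channels before the 1x1 fusion conv.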
93 changes: 93 additions & 0 deletions segmentation_models_pytorch/decoders/segformer/model.py
@@ -0,0 +1,93 @@
from typing import Any, Optional, Union

from segmentation_models_pytorch.base import (
    ClassificationHead,
    SegmentationHead,
    SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import SegformerDecoder


class Segformer(SegmentationModel):
    """Segformer is a simple and efficient design for semantic segmentation with Transformers

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 256
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think of it as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                  (could be **None** to return logits)
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
            Keys with ``None`` values are pruned before passing.

    Returns:
        ``torch.nn.Module``: **Segformer**

    .. _Segformer:
        https://arxiv.org/abs/2105.15203

    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_segmentation_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            **kwargs,
        )

        self.decoder = SegformerDecoder(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            segmentation_channels=decoder_segmentation_channels,
        )

        # 1x1 head on the fused decoder output; upsampling=4 restores input resolution
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_segmentation_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=4,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "segformer-{}".format(encoder_name)
        self.initialize()
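An end-to-end sketch of the class as wired above, including the auxiliary head (encoder name is illustrative; `encoder_weights=None` avoids a weight download):

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Segformer(
    encoder_name="resnet34",
    encoder_weights=None,  # random init for a quick smoke test
    decoder_segmentation_channels=256,
    classes=3,
    aux_params={"classes": 3, "pooling": "avg", "dropout": 0.2},
)
model.eval()

with torch.no_grad():
    mask, label = model(torch.randn(1, 3, 64, 64))

print(mask.shape)   # torch.Size([1, 3, 64, 64]) -- head upsamples 4x to input size
print(label.shape)  # torch.Size([1, 3]) -- auxiliary classification logits
```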
14 changes: 13 additions & 1 deletion tests/test_models.py
@@ -28,6 +28,7 @@ def get_sample(model_class):
        smp.Unet,
        smp.UnetPlusPlus,
        smp.MAnet,
        smp.Segformer,
    ]:
        sample = torch.ones([1, 3, 64, 64])
    elif model_class == smp.PAN:
@@ -61,7 +62,16 @@ def _test_forward_backward(model, sample, test_shape=False):
@pytest.mark.parametrize("encoder_depth", [3, 5])
@pytest.mark.parametrize(
    "model_class",
    [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet],
    [
        smp.FPN,
        smp.PSPNet,
        smp.Linknet,
        smp.Unet,
        smp.UnetPlusPlus,
        smp.MAnet,
        smp.UPerNet,
        smp.Segformer,
    ],
)
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
    if (
@@ -106,6 +116,7 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
        smp.DeepLabV3,
        smp.DeepLabV3Plus,
        smp.UPerNet,
        smp.Segformer,
    ],
)
def test_forward_backward(model_class):
@@ -127,6 +138,7 @@ def test_forward_backward(model_class):
        smp.DeepLabV3,
        smp.DeepLabV3Plus,
        smp.UPerNet,
        smp.Segformer,
    ],
)
def test_aux_output(model_class):
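As a usage note (not part of the diff): assuming pytest's default id generation, which derives parametrize ids from class names, the new cases can be run in isolation with `pytest tests/test_models.py -k "Segformer"`.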