Skip to content

Commit 792c273

Browse files
authored
Segformer backbone Mix Visual Transformer (#632)
* Segformer backbone * Add limitations for FPN, Unet++, Linknet * fix tests
1 parent f58fd6d commit 792c273

File tree

8 files changed

+668
-3
lines changed

8 files changed

+668
-3
lines changed

README.md

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ The main features of this library are:
2020

2121
- High level API (just two lines to create a neural network)
2222
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
23-
- 113 available encoders (and 400+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
23+
- 119 available encoders (and 400+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
2424
- All encoders have pre-trained weights for faster and better convergence
2525
- Popular metrics and losses for training routines
2626

@@ -352,6 +352,29 @@ The following is a list of supported encoders in the SMP. Select the appropriate
352352
</div>
353353
</details>
354354

355+
<details>
356+
<summary style="margin-left: 25px;">Mix Vision Transformer</summary>
357+
<div style="margin-left: 25px;">
358+
359+
Backbone from SegFormer pretrained on Imagenet! Can be used with other decoders from package, you can combine Mix Visual Transformer with Unet, FPN and others!
360+
361+
Limitations:
362+
363+
- encoder is not supported by Linknet, Unet++
364+
- encoder is not supported by FPN if encoder depth != 5
365+
366+
|Encoder |Weights |Params, M |
367+
|--------------------------------|:------------------------------:|:------------------------------:|
368+
|mit_b0 |imagenet |3M |
369+
|mit_b1 |imagenet |13M |
370+
|mit_b2 |imagenet |24M |
371+
|mit_b3 |imagenet |44M |
372+
|mit_b4 |imagenet |60M |
373+
|mit_b5 |imagenet |81M |
374+
375+
</div>
376+
</details>
377+
355378

356379
\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
357380

docs/encoders.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,18 @@ VGG
324324
+-------------+------------+-------------+
325325
| vgg19\_bn | imagenet | 20M |
326326
+-------------+------------+-------------+
327+
328+
329+
Mix Visual Transformer
330+
~~~~~~~~~~~~~~~~~~~~~
331+
332+
+-----------+----------+------------+
333+
| Encoder | Weights | Params, M |
334+
+===========+==========+============+
335+
| mit\_b0 | imagenet | 3M |
336+
| mit\_b1 | imagenet | 13M |
337+
| mit\_b2 | imagenet | 24M |
338+
| mit\_b3 | imagenet | 44M |
339+
| mit\_b4 | imagenet | 60M |
340+
| mit\_b5 | imagenet | 81M |
341+
+-----------+----------+------------+

segmentation_models_pytorch/decoders/fpn/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def __init__(
6666
):
6767
super().__init__()
6868

69+
# validate input params
70+
if encoder_name.startswith("mit_b") and encoder_depth != 5:
71+
raise ValueError("Encoder {} support only encoder_depth=5".format(encoder_name))
72+
6973
self.encoder = get_encoder(
7074
encoder_name,
7175
in_channels=in_channels,

segmentation_models_pytorch/decoders/linknet/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def __init__(
6464
):
6565
super().__init__()
6666

67+
if encoder_name.startswith("mit_b"):
68+
raise ValueError("Encoder `{}` is not supported for Linknet".format(encoder_name))
69+
6770
self.encoder = get_encoder(
6871
encoder_name,
6972
in_channels=in_channels,

segmentation_models_pytorch/decoders/unetplusplus/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def __init__(
6868
):
6969
super().__init__()
7070

71+
if encoder_name.startswith("mit_b"):
72+
raise ValueError("UnetPlusPlus is not support encoder_name={}".format(encoder_name))
73+
7174
self.encoder = get_encoder(
7275
encoder_name,
7376
in_channels=in_channels,

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .timm_sknet import timm_sknet_encoders
2020
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
2121
from .timm_gernet import timm_gernet_encoders
22+
from .mix_transformer import mix_transformer_encoders
2223

2324
from .timm_universal import TimmUniversalEncoder
2425

@@ -42,6 +43,7 @@
4243
encoders.update(timm_sknet_encoders)
4344
encoders.update(timm_mobilenetv3_encoders)
4445
encoders.update(timm_gernet_encoders)
46+
encoders.update(mix_transformer_encoders)
4547

4648

4749
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):

0 commit comments

Comments
 (0)