Skip to content

Commit ad3e5c1

Browse files
authored
Timm universal encoder (qubvel#433)
1 parent 914f2bf commit ad3e5c1

File tree

10 files changed

+1103
-42
lines changed

10 files changed

+1103
-42
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Visit [Read The Docs Project Page](https://smp.readthedocs.io/) or read followin
2525
3. [Models](#models)
2626
1. [Architectures](#architectures)
2727
2. [Encoders](#encoders)
28+
3. [Timm Encoders](#timm)
2829
4. [Models API](#api)
2930
1. [Input channels](#input-channels)
3031
2. [Auxiliary classification output](#auxiliary-classification-output)
@@ -344,6 +345,17 @@ The following is a list of supported encoders in the SMP. Select the appropriate
344345

345346
\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
346347

348+
#### Timm Encoders <a name="timm"></a>
349+
350+
[docs](https://smp.readthedocs.io/en/latest/encoders_timm.html)
351+
352+
PyTorch Image Models (a.k.a. timm) provides many pretrained models and an interface that allows using them as encoders in smp. However, not all models are supported:
353+
354+
- transformer models do not have ``features_only`` functionality implemented
355+
- some models do not have appropriate strides
356+
357+
Total number of supported encoders: 467
358+
- [table with available encoders](https://smp.readthedocs.io/en/latest/encoders_timm.html)
347359

348360
### 🔁 Models API <a name="api"></a>
349361

docs/encoders_timm.rst

Lines changed: 955 additions & 0 deletions
Large diffs are not rendered by default.

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Welcome to Segmentation Models's documentation!
1414
quickstart
1515
models
1616
encoders
17+
encoders_timm
1718
losses
1819
insights
1920

misc/generate_table_timm.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import timm
2+
from tqdm import tqdm
3+
4+
5+
def check_features_and_reduction(name):
    """Raise ``ValueError`` unless timm model ``name`` exposes the expected strides.

    The model is created with ``features_only=True`` and without pretrained
    weights; its ``feature_info.reduction()`` must equal ``[2, 4, 8, 16, 32]``.
    """
    expected_reduction = [2, 4, 8, 16, 32]
    model = timm.create_model(name, features_only=True, pretrained=False)
    if model.feature_info.reduction() == expected_reduction:
        return
    raise ValueError
9+
10+
def has_dilation_support(name):
    """Return True if timm can build ``name`` with output stride 8 and 16.

    Any exception raised while creating either variant is treated as
    "dilation not supported" and yields False.
    """
    try:
        for stride in (8, 16):
            timm.create_model(name, features_only=True, output_stride=stride, pretrained=False)
    except Exception:
        return False
    return True
17+
18+
def make_table(data):
    """Render encoder support data as a reStructuredText grid table.

    Args:
        data: mapping of encoder name -> dict holding a boolean
            ``"has_dilation"`` entry.

    Returns:
        The table as one string. Rows are sorted by encoder name and every
        row is followed by a separator line. Empty ``data`` yields a
        header-only table.
    """
    name_header = "Encoder name"
    dilation_header = "Support dilation"

    # Column widths include one space of padding on each side. The name column
    # is sized from ``data`` (the parameter, not a module-level global) and is
    # never narrower than its header, so short or empty inputs stay aligned.
    max_len1 = max([len(name_header)] + [len(x) for x in data]) + 2
    max_len2 = len(dilation_header) + 2

    l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
    l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
    top = "| " + name_header.ljust(max_len1 - 2) + " | " + dilation_header.center(max_len2 - 2) + " |\n"

    rows = [l1, top, l2]
    for k in sorted(data):
        if data[k]["has_dilation"]:
            # One column narrower than the blank cell; presumably because the
            # check mark renders double-width -- TODO confirm in built docs.
            support = "✅".center(max_len2 - 3)
        else:
            support = " ".center(max_len2 - 2)
        rows.append("| " + k.ljust(max_len1 - 2) + " | " + support + " |\n")
        rows.append(l1)

    return "".join(rows)
35+
36+
37+
if __name__ == "__main__":

    # Probe every model timm knows about and keep those that pass the
    # feature/stride check, recording whether each supports dilation.
    supported_models = {}

    with tqdm(timm.list_models()) as progress:
        for model_name in progress:
            try:
                check_features_and_reduction(model_name)
                dilation_ok = has_dilation_support(model_name)
                supported_models[model_name] = {"has_dilation": dilation_ok}
            except Exception:
                # Best-effort scan: a model that fails the check is simply left out.
                continue

    print(make_table(supported_models))
    print(f"Total encoders: {len(supported_models)}")

segmentation_models_pytorch/deeplabv3/model.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,7 @@ def __init__(
5858
in_channels=in_channels,
5959
depth=encoder_depth,
6060
weights=encoder_weights,
61-
)
62-
self.encoder.make_dilated(
63-
stage_list=[4, 5],
64-
dilation_list=[2, 4]
61+
output_stride=8,
6562
)
6663

6764
self.decoder = DeepLabV3Decoder(
@@ -136,29 +133,19 @@ def __init__(
136133
):
137134
super().__init__()
138135

136+
if encoder_output_stride not in [8, 16]:
137+
raise ValueError(
138+
"Encoder output stride should be 8 or 16, got {}".format(encoder_output_stride)
139+
)
140+
139141
self.encoder = get_encoder(
140142
encoder_name,
141143
in_channels=in_channels,
142144
depth=encoder_depth,
143145
weights=encoder_weights,
146+
output_stride=encoder_output_stride,
144147
)
145148

146-
if encoder_output_stride == 8:
147-
self.encoder.make_dilated(
148-
stage_list=[4, 5],
149-
dilation_list=[2, 4]
150-
)
151-
152-
elif encoder_output_stride == 16:
153-
self.encoder.make_dilated(
154-
stage_list=[5],
155-
dilation_list=[2]
156-
)
157-
else:
158-
raise ValueError(
159-
"Encoder output stride should be 8 or 16, got {}".format(encoder_output_stride)
160-
)
161-
162149
self.decoder = DeepLabV3PlusDecoder(
163150
encoder_channels=self.encoder.out_channels,
164151
out_channels=decoder_channels,

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
2020
from .timm_gernet import timm_gernet_encoders
2121

22+
from .timm_universal import TimmUniversalEncoder
23+
2224
from ._preprocessing import preprocess_input
2325

2426
encoders = {}
@@ -41,7 +43,19 @@
4143
encoders.update(timm_gernet_encoders)
4244

4345

44-
def get_encoder(name, in_channels=3, depth=5, weights=None):
46+
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build an encoder by name.

    Names starting with ``"tu-"`` are built with ``TimmUniversalEncoder``:
    the prefix is removed, ``pretrained`` is set when ``weights`` is not
    None, and extra ``kwargs`` are forwarded. Any other name is resolved
    through the ``encoders`` registry by the code that follows this branch.
    """
    tu_prefix = "tu-"
    if name.startswith(tu_prefix):
        # Slice the prefix off. ``str.lstrip("tu-")`` would strip every leading
        # "t", "u" or "-" character, mangling names such as "tu-tresnet_m"
        # into "resnet_m".
        name = name[len(tu_prefix):]
        encoder = TimmUniversalEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )
        return encoder
4559

4660
try:
4761
Encoder = encoders[name]["encoder"]
@@ -62,7 +76,9 @@ def get_encoder(name, in_channels=3, depth=5, weights=None):
6276
encoder.load_state_dict(model_zoo.load_url(settings["url"]))
6377

6478
encoder.set_in_channels(in_channels, pretrained=weights is not None)
65-
79+
if output_stride != 32:
80+
encoder.make_dilated(output_stride)
81+
6682
return encoder
6783

6884

segmentation_models_pytorch/encoders/_base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,19 @@ def get_stages(self):
3232
"""Method should be overridden in encoder"""
3333
raise NotImplementedError
3434

35-
def make_dilated(self, stage_list, dilation_list):
35+
def make_dilated(self, output_stride):
36+
37+
if output_stride == 16:
38+
stage_list=[5,]
39+
dilation_list=[2,]
40+
41+
elif output_stride == 8:
42+
stage_list=[4, 5]
43+
dilation_list=[2, 4]
44+
45+
else:
46+
raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))
47+
3648
stages = self.get_stages()
3749
for stage_indx, dilation_rate in zip(stage_list, dilation_list):
3850
utils.replace_strides_with_dilation(
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import timm
2+
import torch.nn as nn
3+
4+
5+
class TimmUniversalEncoder(nn.Module):
    """Wrap a timm ``features_only`` model so it can serve as an encoder.

    Args:
        name: timm model name passed to ``timm.create_model``.
        pretrained: forwarded to ``timm.create_model``.
        in_channels: number of input channels (timm ``in_chans``).
        depth: number of feature stages requested
            (``out_indices=tuple(range(depth))``).
        output_stride: forwarded to timm unless it equals 32.
    """

    def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32):
        super().__init__()
        kwargs = dict(
            in_chans=in_channels,
            features_only=True,
            output_stride=output_stride,
            pretrained=pretrained,
            out_indices=tuple(range(depth)),
        )

        # not all models support output stride argument, drop it by default
        if output_stride == 32:
            kwargs.pop("output_stride")

        self.model = timm.create_model(name, **kwargs)

        self._in_channels = in_channels
        # The first returned feature is the raw input tensor, so its channel
        # count is ``in_channels`` rather than a hard-coded 3.
        self._out_channels = [in_channels, ] + self.model.feature_info.channels()
        self._depth = depth

    def forward(self, x):
        """Return ``[x, *stage_features]``: the input followed by the timm features."""
        features = self.model(x)
        return [x, *features]

    @property
    def out_channels(self):
        """Channel counts matching the list returned by ``forward``."""
        return self._out_channels

segmentation_models_pytorch/pan/model.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class PAN(SegmentationModel):
1717
to extract features of different spatial resolution
1818
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
1919
other pretrained weights (see table with available weights for each encoder_name)
20-
encoder_dilation: Flag to use dilation in encoder last layer. Doesn't work with ***ception***, **vgg***,
21-
**densenet*`** backbones, default is **True**
20+
encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer.
21+
Doesn't work with ***ception***, **vgg***, **densenet*** backbones. Default is 16.
2222
decoder_channels: A number of convolution layer filters in decoder blocks
2323
in_channels: A number of input channels for the model, default is 3 (RGB images)
2424
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
@@ -45,7 +45,7 @@ def __init__(
4545
self,
4646
encoder_name: str = "resnet34",
4747
encoder_weights: Optional[str] = "imagenet",
48-
encoder_dilation: bool = True,
48+
encoder_output_stride: int = 16,
4949
decoder_channels: int = 32,
5050
in_channels: int = 3,
5151
classes: int = 1,
@@ -55,19 +55,17 @@ def __init__(
5555
):
5656
super().__init__()
5757

58+
if encoder_output_stride not in [16, 32]:
59+
raise ValueError("PAN support output stride 16 or 32, got {}".format(encoder_output_stride))
60+
5861
self.encoder = get_encoder(
5962
encoder_name,
6063
in_channels=in_channels,
6164
depth=5,
6265
weights=encoder_weights,
66+
output_stride=encoder_output_stride,
6367
)
6468

65-
if encoder_dilation:
66-
self.encoder.make_dilated(
67-
stage_list=[5],
68-
dilation_list=[2]
69-
)
70-
7169
self.decoder = PANDecoder(
7270
encoder_channels=self.encoder.out_channels,
7371
decoder_channels=decoder_channels,

tests/test_models.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,17 @@
88
sys.modules["torchvision._C"] = mock.Mock()
99
import segmentation_models_pytorch as smp
1010

11-
IS_TRAVIS = os.environ.get("TRAVIS", False)
12-
1311

1412
def get_encoders():
15-
travis_exclude_encoders = [
13+
exclude_encoders = [
1614
"senet154",
1715
"resnext101_32x16d",
1816
"resnext101_32x32d",
1917
"resnext101_32x48d",
2018
]
2119
encoders = smp.encoders.get_encoder_names()
22-
if IS_TRAVIS:
23-
encoders = [e for e in encoders if e not in travis_exclude_encoders]
20+
encoders = [e for e in encoders if e not in exclude_encoders]
21+
encoders.append("tu-resnet34") # for timm universal encoder
2422
return encoders
2523

2624

@@ -127,11 +125,7 @@ def test_dilation(encoder_name):
127125
):
128126
return
129127

130-
encoder = smp.encoders.get_encoder(encoder_name)
131-
encoder.make_dilated(
132-
stage_list=[5],
133-
dilation_list=[2],
134-
)
128+
encoder = smp.encoders.get_encoder(encoder_name, output_stride=16)
135129

136130
encoder.eval()
137131
with torch.no_grad():

0 commit comments

Comments
 (0)