
Commit 04f59db

handle different drop rates for EfficientNet, add timm-tf_efficientnet_lite0-lite4 (#314)
* handle different drop rates for EfficientNet, add timm-tf_efficientnet_lite0-lite4; it's not clear which dataset lite0-4 were pretrained with, I set it to 'imagenet' but I've noticed the mean=(0.5, 0.5, 0.5)
* readme
* minor typos
* correct weights from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
* fix lite3/4 models
* Update encoders.rst
1 parent 5f5f639 commit 04f59db
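
The new encoders plug into the library's existing high-level API. A minimal usage sketch, assuming the encoder is selected by the registry key added in `timm_efficientnet.py` below (input size and class count are illustrative):

```python
import torch
import segmentation_models_pytorch as smp

# Assumption: the new lite encoders are addressed by the registry keys added in
# this commit, e.g. "timm-tf_efficientnet_lite0", with 'imagenet' weights.
model = smp.Unet(
    encoder_name="timm-tf_efficientnet_lite0",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)

mask = model(torch.ones([1, 3, 224, 224]))  # -> [1, 2, 224, 224]
```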

File tree

3 files changed: 168 additions, 21 deletions


README.md

Lines changed: 15 additions & 10 deletions
@@ -10,9 +10,9 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
 
 The main features of this library are:
 
-- High level API (just two lines to create neural network)
+- High level API (just two lines to create a neural network)
 - 9 models architectures for binary and multi class segmentation (including legendary Unet)
-- 99 available encoders
+- 104 available encoders
 - All encoders have pre-trained weights for faster and better convergence
 
 ### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -23,7 +23,7 @@ Visit [Read The Docs Project Page](https://smp.readthedocs.io/) or read followin
 1. [Quick start](#start)
 2. [Examples](#examples)
 3. [Models](#models)
-    1. [Architectures](#architectires)
+    1. [Architectures](#architectures)
     2. [Encoders](#encoders)
 4. [Models API](#api)
     1. [Input channels](#input-channels)
@@ -46,13 +46,13 @@ import segmentation_models_pytorch as smp
 
 model = smp.Unet(
     encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
-    encoder_weights="imagenet",     # use `imagenet` pretreined weights for encoder initialization
+    encoder_weights="imagenet",     # use `imagenet` pretrained weights for encoder initialization
     in_channels=1,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
     classes=3,                      # model output channels (number of classes in your dataset)
 )
 ```
 - see [table](#architectires) with available model architectures
-- see [table](#encoders) with avaliable encoders and its corresponding weights
+- see [table](#encoders) with available encoders and their corresponding weights
 
 #### 2. Configure data preprocessing
 
@@ -73,7 +73,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
 
 ### 📦 Models <a name="models"></a>
 
-#### Architectures <a name="architectires"></a>
+#### Architectures <a name="architectures"></a>
 - Unet [[paper](https://arxiv.org/abs/1505.04597)] [[docs](https://smp.readthedocs.io/en/latest/models.html#unet)]
 - Unet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)]
 - MAnet [[paper](https://ieeexplore.ieee.org/abstract/document/9201310)] [[docs](https://smp.readthedocs.io/en/latest/models.html#manet)]
@@ -268,6 +268,11 @@ The following is a list of supported encoders in the SMP. Select the appropriate
 |timm-efficientnet-b7 |imagenet / advprop / noisy-student|63M |
 |timm-efficientnet-b8 |imagenet / advprop |84M |
 |timm-efficientnet-l2 |noisy-student |474M |
+|timm-efficientnet-lite0 |imagenet |4M |
+|timm-efficientnet-lite1 |imagenet |5M |
+|timm-efficientnet-lite2 |imagenet |6M |
+|timm-efficientnet-lite3 |imagenet |8M |
+|timm-efficientnet-lite4 |imagenet |13M |
 
 </div>
 </details>
@@ -330,7 +335,7 @@ The following is a list of supported encoders in the SMP. Select the appropriate
 - `model.forward(x)` - sequentially pass `x` through model\`s encoder, decoder and segmentation head (and classification head if specified)
 
 ##### Input channels
-Input channels parameter allow you to create models, which process tensors with arbitrary number of channels.
+Input channels parameter allows you to create models, which process tensors with arbitrary number of channels.
 If you use pretrained weights from imagenet - weights of first convolution will be reused for
 1- or 2- channels inputs, for input channels > 4 weights of first convolution will be initialized randomly.
 ```python
@@ -340,9 +345,9 @@ mask = model(torch.ones([1, 1, 64, 64]))
 
 ##### Auxiliary classification output
 All models support `aux_params` parameters, which is default set to `None`.
-If `aux_params = None` than classification auxiliary output is not created, else
+If `aux_params = None` then classification auxiliary output is not created, else
 model produce not only `mask`, but also `label` output with shape `NC`.
-Classification head consist of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be
+Classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be
 configured by `aux_params` as follows:
 ```python
 aux_params=dict(
@@ -357,7 +362,7 @@ mask, label = model(x)
 
 ##### Depth
 Depth parameter specify a number of downsampling operations in encoder, so you can make
-your model lighted if specify smaller `depth`.
+your model lighter if specify smaller `depth`.
 ```python
 model = smp.Unet('resnet34', encoder_depth=4)
 ```
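
For the `aux_params` passage touched above, a short illustrative snippet; the key names follow the GlobalPooling->Dropout(optional)->Linear->Activation(optional) head described in the README, and the values are examples only:

```python
import torch
import segmentation_models_pytorch as smp

# Illustrative aux_params; each key corresponds to one stage of the
# classification head described above.
aux_params = dict(
    pooling="avg",         # global pooling: "avg" or "max"
    dropout=0.5,           # optional dropout before the linear layer
    activation="sigmoid",  # optional activation on the label output
    classes=4,             # number of classification outputs
)
model = smp.Unet("resnet34", classes=4, aux_params=aux_params)
mask, label = model(torch.ones([1, 3, 64, 64]))  # label shape: [1, 4]
```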

docs/encoders.rst

Lines changed: 10 additions & 0 deletions
@@ -238,6 +238,16 @@ EfficientNet
 +------------------------+--------------------------------------+-------------+
 | timm-efficientnet-l2   | noisy-student                        | 474M        |
 +------------------------+--------------------------------------+-------------+
+| timm-efficientnet-lite0| imagenet                             | 4M          |
++------------------------+--------------------------------------+-------------+
+| timm-efficientnet-lite1| imagenet                             | 4M          |
++------------------------+--------------------------------------+-------------+
+| timm-efficientnet-lite2| imagenet                             | 6M          |
++------------------------+--------------------------------------+-------------+
+| timm-efficientnet-lite3| imagenet                             | 8M          |
++------------------------+--------------------------------------+-------------+
+| timm-efficientnet-lite4| imagenet                             | 13M         |
++------------------------+--------------------------------------+-------------+
 
 MobileNet
 ~~~~~~~~~

segmentation_models_pytorch/encoders/timm_efficientnet.py

Lines changed: 143 additions & 11 deletions
@@ -8,7 +8,7 @@
 from ._base import EncoderMixin
 
 
-def get_efficientnet_kwargs(channel_multiplier=1.0, depth_multiplier=1.0):
+def get_efficientnet_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
     """Creates an EfficientNet model.
     Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
     Paper: https://arxiv.org/abs/1905.11946
@@ -44,24 +44,62 @@ def get_efficientnet_kwargs(channel_multiplier=1.0, depth_multiplier=1.0):
         channel_multiplier=channel_multiplier,
         act_layer=Swish,
         norm_kwargs={},  # TODO: check
-        drop_rate=0.2,
+        drop_rate=drop_rate,
         drop_path_rate=0.2,
     )
     return model_kwargs
 
+def gen_efficientnet_lite_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
+    """Creates an EfficientNet-Lite model.
 
-class EfficientNetEncoder(EfficientNet, EncoderMixin):
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
+    Paper: https://arxiv.org/abs/1905.11946
+
+    EfficientNet params
+    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
+      'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
+      'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
+      'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
+      'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
+      'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer
+      depth_multiplier: multiplier to number of repeats per stage
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16'],
+        ['ir_r2_k3_s2_e6_c24'],
+        ['ir_r2_k5_s2_e6_c40'],
+        ['ir_r3_k3_s2_e6_c80'],
+        ['ir_r3_k5_s1_e6_c112'],
+        ['ir_r4_k5_s2_e6_c192'],
+        ['ir_r1_k3_s1_e6_c320'],
+    ]
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
+        num_features=1280,
+        stem_size=32,
+        fix_stem=True,
+        channel_multiplier=channel_multiplier,
+        act_layer=nn.ReLU6,
+        norm_kwargs={},
+        drop_rate=drop_rate,
+        drop_path_rate=0.2,
+    )
+    return model_kwargs
+
+class EfficientNetBaseEncoder(EfficientNet, EncoderMixin):
 
-    def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0):
-        kwargs = get_efficientnet_kwargs(channel_multiplier, depth_multiplier)
-        super().__init__(**kwargs)
+    def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
+        super().__init__(**kwargs)
 
-        self._stage_idxs = stage_idxs
-        self._out_channels = out_channels
-        self._depth = depth
-        self._in_channels = 3
+        self._stage_idxs = stage_idxs
+        self._out_channels = out_channels
+        self._depth = depth
+        self._in_channels = 3
 
-        del self.classifier
+        del self.classifier
 
     def get_stages(self):
         return [
@@ -89,6 +127,20 @@ def load_state_dict(self, state_dict, **kwargs):
         super().load_state_dict(state_dict, **kwargs)
 
 
+class EfficientNetEncoder(EfficientNetBaseEncoder):
+
+    def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
+        kwargs = get_efficientnet_kwargs(channel_multiplier, depth_multiplier, drop_rate)
+        super().__init__(stage_idxs, out_channels, depth, **kwargs)
+
+
+class EfficientNetLiteEncoder(EfficientNetBaseEncoder):
+
+    def __init__(self, stage_idxs, out_channels, depth=5, channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
+        kwargs = gen_efficientnet_lite_kwargs(channel_multiplier, depth_multiplier, drop_rate)
+        super().__init__(stage_idxs, out_channels, depth, **kwargs)
+
+
 def prepare_settings(settings):
     return {
         "mean": settings["mean"],
@@ -113,6 +165,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 1.0,
             "depth_multiplier": 1.0,
+            "drop_rate": 0.2,
         },
     },
 
@@ -128,6 +181,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 1.0,
             "depth_multiplier": 1.1,
+            "drop_rate": 0.2,
         },
     },
 
@@ -143,6 +197,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 1.1,
             "depth_multiplier": 1.2,
+            "drop_rate": 0.3,
         },
     },
 
@@ -158,6 +213,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 1.2,
             "depth_multiplier": 1.4,
+            "drop_rate": 0.3,
         },
     },
 
@@ -173,6 +229,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 1.4,
             "depth_multiplier": 1.8,
+            "drop_rate": 0.4,
         },
     },
 
@@ -188,6 +245,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 1.6,
             "depth_multiplier": 2.2,
+            "drop_rate": 0.4,
         },
     },
 
@@ -203,6 +261,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 1.8,
             "depth_multiplier": 2.6,
+            "drop_rate": 0.5,
         },
     },
 
@@ -218,6 +277,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 2.0,
             "depth_multiplier": 3.1,
+            "drop_rate": 0.5,
         },
     },
 
@@ -232,6 +292,7 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 2.2,
             "depth_multiplier": 3.6,
+            "drop_rate": 0.5,
         },
     },
 
@@ -245,6 +306,77 @@ def prepare_settings(settings):
             "stage_idxs": (2, 3, 5),
             "channel_multiplier": 4.3,
             "depth_multiplier": 5.3,
+            "drop_rate": 0.5,
+        },
+    },
+
+    "timm-tf_efficientnet_lite0": {
+        "encoder": EfficientNetLiteEncoder,
+        "pretrained_settings": {
+            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite0"]),
+        },
+        "params": {
+            "out_channels": (3, 32, 24, 40, 112, 320),
+            "stage_idxs": (2, 3, 5),
+            "channel_multiplier": 1.0,
+            "depth_multiplier": 1.0,
+            "drop_rate": 0.2,
+        },
+    },
+
+    "timm-tf_efficientnet_lite1": {
+        "encoder": EfficientNetLiteEncoder,
+        "pretrained_settings": {
+            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite1"]),
+        },
+        "params": {
+            "out_channels": (3, 32, 24, 40, 112, 320),
+            "stage_idxs": (2, 3, 5),
+            "channel_multiplier": 1.0,
+            "depth_multiplier": 1.1,
+            "drop_rate": 0.2,
+        },
+    },
+
+    "timm-tf_efficientnet_lite2": {
+        "encoder": EfficientNetLiteEncoder,
+        "pretrained_settings": {
+            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite2"]),
+        },
+        "params": {
+            "out_channels": (3, 32, 24, 48, 120, 352),
+            "stage_idxs": (2, 3, 5),
+            "channel_multiplier": 1.1,
+            "depth_multiplier": 1.2,
+            "drop_rate": 0.3,
+        },
+    },
+
+    "timm-tf_efficientnet_lite3": {
+        "encoder": EfficientNetLiteEncoder,
+        "pretrained_settings": {
+            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite3"]),
+        },
+        "params": {
+            "out_channels": (3, 32, 32, 48, 136, 384),
+            "stage_idxs": (2, 3, 5),
+            "channel_multiplier": 1.2,
+            "depth_multiplier": 1.4,
+            "drop_rate": 0.3,
+        },
+    },
+
+    "timm-tf_efficientnet_lite4": {
+        "encoder": EfficientNetLiteEncoder,
+        "pretrained_settings": {
+            "imagenet": prepare_settings(default_cfgs["tf_efficientnet_lite4"]),
+        },
+        "params": {
+            "out_channels": (3, 32, 32, 56, 160, 448),
+            "stage_idxs": (2, 3, 5),
+            "channel_multiplier": 1.4,
+            "depth_multiplier": 1.8,
+            "drop_rate": 0.4,
         },
     },
 }
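
To make the new entries concrete, a rough construction sketch using the `timm-tf_efficientnet_lite0` parameters registered above (pretrained weights are not loaded here, and the per-stage feature list returned by forward is how the library's encoders generally behave, noted below as an assumption):

```python
import torch

# Direct construction with the lite0 parameters from the registry entry above.
encoder = EfficientNetLiteEncoder(
    stage_idxs=(2, 3, 5),
    out_channels=(3, 32, 24, 40, 112, 320),
    channel_multiplier=1.0,
    depth_multiplier=1.0,
    drop_rate=0.2,
)

# Assumption: like the other encoders in this package, forward returns one
# feature map per stage, with channel counts matching out_channels.
features = encoder(torch.ones([1, 3, 224, 224]))
```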
