8
8
from ._base import EncoderMixin
9
9
10
10
11
- def get_efficientnet_kwargs (channel_multiplier = 1.0 , depth_multiplier = 1.0 ):
11
+ def get_efficientnet_kwargs (channel_multiplier = 1.0 , depth_multiplier = 1.0 , drop_rate = 0.2 ):
12
12
"""Creates an EfficientNet model.
13
13
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
14
14
Paper: https://arxiv.org/abs/1905.11946
@@ -44,24 +44,62 @@ def get_efficientnet_kwargs(channel_multiplier=1.0, depth_multiplier=1.0):
44
44
channel_multiplier = channel_multiplier ,
45
45
act_layer = Swish ,
46
46
norm_kwargs = {}, # TODO: check
47
- drop_rate = 0.2 ,
47
+ drop_rate = drop_rate ,
48
48
drop_path_rate = 0.2 ,
49
49
)
50
50
return model_kwargs
51
51
52
def gen_efficientnet_lite_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_rate=0.2):
    """Build the keyword arguments used to construct an EfficientNet-Lite model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet-Lite params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
      'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
      'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
      'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
      'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
      'efficientnet-lite4': (1.4, 1.8, 300, 0.3),

    Args:
        channel_multiplier: multiplier to number of channels per layer
        depth_multiplier: multiplier to number of repeats per stage
        drop_rate: value stored under the ``drop_rate`` key of the result

    Returns:
        dict of keyword arguments (block_args, num_features, stem_size,
        fix_stem, channel_multiplier, act_layer, norm_kwargs, drop_rate,
        drop_path_rate).
    """
    # One block-definition string per stage, listed in network order; the
    # strings are interpreted by decode_arch_def.
    stage_specs = (
        'ds_r1_k3_s1_e1_c16',
        'ir_r2_k3_s2_e6_c24',
        'ir_r2_k5_s2_e6_c40',
        'ir_r3_k3_s2_e6_c80',
        'ir_r3_k5_s1_e6_c112',
        'ir_r4_k5_s2_e6_c192',
        'ir_r1_k3_s1_e6_c320',
    )
    # decode_arch_def expects a list of stages, each a list of block strings.
    arch_def = [[spec] for spec in stage_specs]
    block_args = decode_arch_def(arch_def, depth_multiplier, fix_first_last=True)

    return {
        "block_args": block_args,
        "num_features": 1280,
        "stem_size": 32,
        "fix_stem": True,
        "channel_multiplier": channel_multiplier,
        "act_layer": nn.ReLU6,
        "norm_kwargs": {},
        "drop_rate": drop_rate,
        "drop_path_rate": 0.2,
    }
91
+
92
+ class EfficientNetBaseEncoder (EfficientNet , EncoderMixin ):
54
93
55
- def __init__ (self , stage_idxs , out_channels , depth = 5 , channel_multiplier = 1.0 , depth_multiplier = 1.0 ):
56
- kwargs = get_efficientnet_kwargs (channel_multiplier , depth_multiplier )
57
- super ().__init__ (** kwargs )
94
+ def __init__ (self , stage_idxs , out_channels , depth = 5 , ** kwargs ):
95
+ super ().__init__ (** kwargs )
58
96
59
- self ._stage_idxs = stage_idxs
60
- self ._out_channels = out_channels
61
- self ._depth = depth
62
- self ._in_channels = 3
97
+ self ._stage_idxs = stage_idxs
98
+ self ._out_channels = out_channels
99
+ self ._depth = depth
100
+ self ._in_channels = 3
63
101
64
- del self .classifier
102
+ del self .classifier
65
103
66
104
def get_stages (self ):
67
105
return [
@@ -89,6 +127,20 @@ def load_state_dict(self, state_dict, **kwargs):
89
127
super ().load_state_dict (state_dict , ** kwargs )
90
128
91
129
130
class EfficientNetEncoder(EfficientNetBaseEncoder):
    """Encoder configured through ``get_efficientnet_kwargs``.

    ``stage_idxs``, ``out_channels`` and ``depth`` are forwarded unchanged to
    ``EfficientNetBaseEncoder``; the three scaling arguments are turned into
    model keyword arguments first.
    """

    def __init__(
        self,
        stage_idxs,
        out_channels,
        depth=5,
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        drop_rate=0.2,
    ):
        model_kwargs = get_efficientnet_kwargs(
            channel_multiplier=channel_multiplier,
            depth_multiplier=depth_multiplier,
            drop_rate=drop_rate,
        )
        super().__init__(
            stage_idxs=stage_idxs,
            out_channels=out_channels,
            depth=depth,
            **model_kwargs,
        )
135
+
136
+
137
class EfficientNetLiteEncoder(EfficientNetBaseEncoder):
    """Encoder configured through ``gen_efficientnet_lite_kwargs``.

    ``stage_idxs``, ``out_channels`` and ``depth`` are forwarded unchanged to
    ``EfficientNetBaseEncoder``; the three scaling arguments are turned into
    model keyword arguments first.
    """

    def __init__(
        self,
        stage_idxs,
        out_channels,
        depth=5,
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        drop_rate=0.2,
    ):
        model_kwargs = gen_efficientnet_lite_kwargs(
            channel_multiplier=channel_multiplier,
            depth_multiplier=depth_multiplier,
            drop_rate=drop_rate,
        )
        super().__init__(
            stage_idxs=stage_idxs,
            out_channels=out_channels,
            depth=depth,
            **model_kwargs,
        )
142
+
143
+
92
144
def prepare_settings (settings ):
93
145
return {
94
146
"mean" : settings ["mean" ],
@@ -113,6 +165,7 @@ def prepare_settings(settings):
113
165
"stage_idxs" : (2 , 3 , 5 ),
114
166
"channel_multiplier" : 1.0 ,
115
167
"depth_multiplier" : 1.0 ,
168
+ "drop_rate" : 0.2 ,
116
169
},
117
170
},
118
171
@@ -128,6 +181,7 @@ def prepare_settings(settings):
128
181
"stage_idxs" : (2 , 3 , 5 ),
129
182
"channel_multiplier" : 1.0 ,
130
183
"depth_multiplier" : 1.1 ,
184
+ "drop_rate" : 0.2 ,
131
185
},
132
186
},
133
187
@@ -143,6 +197,7 @@ def prepare_settings(settings):
143
197
"stage_idxs" : (2 , 3 , 5 ),
144
198
"channel_multiplier" : 1.1 ,
145
199
"depth_multiplier" : 1.2 ,
200
+ "drop_rate" : 0.3 ,
146
201
},
147
202
},
148
203
@@ -158,6 +213,7 @@ def prepare_settings(settings):
158
213
"stage_idxs" : (2 , 3 , 5 ),
159
214
"channel_multiplier" : 1.2 ,
160
215
"depth_multiplier" : 1.4 ,
216
+ "drop_rate" : 0.3 ,
161
217
},
162
218
},
163
219
@@ -173,6 +229,7 @@ def prepare_settings(settings):
173
229
"stage_idxs" : (2 , 3 , 5 ),
174
230
"channel_multiplier" : 1.4 ,
175
231
"depth_multiplier" : 1.8 ,
232
+ "drop_rate" : 0.4 ,
176
233
},
177
234
},
178
235
@@ -188,6 +245,7 @@ def prepare_settings(settings):
188
245
"stage_idxs" : (2 , 3 , 5 ),
189
246
"channel_multiplier" : 1.6 ,
190
247
"depth_multiplier" : 2.2 ,
248
+ "drop_rate" : 0.4 ,
191
249
},
192
250
},
193
251
@@ -203,6 +261,7 @@ def prepare_settings(settings):
203
261
"stage_idxs" : (2 , 3 , 5 ),
204
262
"channel_multiplier" : 1.8 ,
205
263
"depth_multiplier" : 2.6 ,
264
+ "drop_rate" : 0.5 ,
206
265
},
207
266
},
208
267
@@ -218,6 +277,7 @@ def prepare_settings(settings):
218
277
"stage_idxs" : (2 , 3 , 5 ),
219
278
"channel_multiplier" : 2.0 ,
220
279
"depth_multiplier" : 3.1 ,
280
+ "drop_rate" : 0.5 ,
221
281
},
222
282
},
223
283
@@ -232,6 +292,7 @@ def prepare_settings(settings):
232
292
"stage_idxs" : (2 , 3 , 5 ),
233
293
"channel_multiplier" : 2.2 ,
234
294
"depth_multiplier" : 3.6 ,
295
+ "drop_rate" : 0.5 ,
235
296
},
236
297
},
237
298
@@ -245,6 +306,77 @@ def prepare_settings(settings):
245
306
"stage_idxs" : (2 , 3 , 5 ),
246
307
"channel_multiplier" : 4.3 ,
247
308
"depth_multiplier" : 5.3 ,
309
+ "drop_rate" : 0.5 ,
310
+ },
311
+ },
312
+
313
+ "timm-tf_efficientnet_lite0" : {
314
+ "encoder" : EfficientNetLiteEncoder ,
315
+ "pretrained_settings" : {
316
+ "imagenet" : prepare_settings (default_cfgs ["tf_efficientnet_lite0" ]),
317
+ },
318
+ "params" : {
319
+ "out_channels" : (3 , 32 , 24 , 40 , 112 , 320 ),
320
+ "stage_idxs" : (2 , 3 , 5 ),
321
+ "channel_multiplier" : 1.0 ,
322
+ "depth_multiplier" : 1.0 ,
323
+ "drop_rate" : 0.2 ,
324
+ },
325
+ },
326
+
327
+ "timm-tf_efficientnet_lite1" : {
328
+ "encoder" : EfficientNetLiteEncoder ,
329
+ "pretrained_settings" : {
330
+ "imagenet" : prepare_settings (default_cfgs ["tf_efficientnet_lite1" ]),
331
+ },
332
+ "params" : {
333
+ "out_channels" : (3 , 32 , 24 , 40 , 112 , 320 ),
334
+ "stage_idxs" : (2 , 3 , 5 ),
335
+ "channel_multiplier" : 1.0 ,
336
+ "depth_multiplier" : 1.1 ,
337
+ "drop_rate" : 0.2 ,
338
+ },
339
+ },
340
+
341
+ "timm-tf_efficientnet_lite2" : {
342
+ "encoder" : EfficientNetLiteEncoder ,
343
+ "pretrained_settings" : {
344
+ "imagenet" : prepare_settings (default_cfgs ["tf_efficientnet_lite2" ]),
345
+ },
346
+ "params" : {
347
+ "out_channels" : (3 , 32 , 24 , 48 , 120 , 352 ),
348
+ "stage_idxs" : (2 , 3 , 5 ),
349
+ "channel_multiplier" : 1.1 ,
350
+ "depth_multiplier" : 1.2 ,
351
+ "drop_rate" : 0.3 ,
352
+ },
353
+ },
354
+
355
+ "timm-tf_efficientnet_lite3" : {
356
+ "encoder" : EfficientNetLiteEncoder ,
357
+ "pretrained_settings" : {
358
+ "imagenet" : prepare_settings (default_cfgs ["tf_efficientnet_lite3" ]),
359
+ },
360
+ "params" : {
361
+ "out_channels" : (3 , 32 , 32 , 48 , 136 , 384 ),
362
+ "stage_idxs" : (2 , 3 , 5 ),
363
+ "channel_multiplier" : 1.2 ,
364
+ "depth_multiplier" : 1.4 ,
365
+ "drop_rate" : 0.3 ,
366
+ },
367
+ },
368
+
369
+ "timm-tf_efficientnet_lite4" : {
370
+ "encoder" : EfficientNetLiteEncoder ,
371
+ "pretrained_settings" : {
372
+ "imagenet" : prepare_settings (default_cfgs ["tf_efficientnet_lite4" ]),
373
+ },
374
+ "params" : {
375
+ "out_channels" : (3 , 32 , 32 , 56 , 160 , 448 ),
376
+ "stage_idxs" : (2 , 3 , 5 ),
377
+ "channel_multiplier" : 1.4 ,
378
+ "depth_multiplier" : 1.8 ,
379
+ "drop_rate" : 0.4 ,
248
380
},
249
381
},
250
382
}
0 commit comments