13
13
import math
14
14
import collections
15
15
from functools import partial
16
+ from typing import List , Optional
17
+
18
+ # Parameters for the entire model (stem, all blocks, and head)
19
+ GlobalParams = collections .namedtuple (
20
+ "GlobalParams" ,
21
+ [
22
+ "width_coefficient" ,
23
+ "depth_coefficient" ,
24
+ "image_size" ,
25
+ "dropout_rate" ,
26
+ "num_classes" ,
27
+ "batch_norm_momentum" ,
28
+ "batch_norm_epsilon" ,
29
+ "drop_connect_rate" ,
30
+ "depth_divisor" ,
31
+ "min_depth" ,
32
+ "include_top" ,
33
+ ],
34
+ )
35
+
36
+ # Parameters for an individual model block
37
+ BlockArgs = collections .namedtuple (
38
+ "BlockArgs" ,
39
+ [
40
+ "num_repeat" ,
41
+ "kernel_size" ,
42
+ "stride" ,
43
+ "expand_ratio" ,
44
+ "input_filters" ,
45
+ "output_filters" ,
46
+ "se_ratio" ,
47
+ "id_skip" ,
48
+ ],
49
+ )
50
+
51
+ # Set GlobalParams and BlockArgs's defaults
52
+ GlobalParams .__new__ .__defaults__ = (None ,) * len (GlobalParams ._fields )
53
+ BlockArgs .__new__ .__defaults__ = (None ,) * len (BlockArgs ._fields )
16
54
17
55
18
56
class MBConvBlock (nn .Module ):
@@ -29,77 +67,94 @@ class MBConvBlock(nn.Module):
29
67
[3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
30
68
"""
31
69
32
- def __init__ (self , block_args , global_params , image_size = None ):
70
+ def __init__ (
71
+ self , block_args : BlockArgs , global_params : GlobalParams , image_size = None
72
+ ):
33
73
super ().__init__ ()
34
- self ._block_args = block_args
35
- self ._bn_mom = (
36
- 1 - global_params .batch_norm_momentum
37
- ) # pytorch's difference from tensorflow
38
- self ._bn_eps = global_params .batch_norm_epsilon
39
- self .has_se = (self ._block_args .se_ratio is not None ) and (
40
- 0 < self ._block_args .se_ratio <= 1
41
- )
42
- self .id_skip = (
74
+
75
+ self ._has_expansion = block_args .expand_ratio != 1
76
+ self ._has_se = block_args .se_ratio is not None and 0 < block_args .se_ratio <= 1
77
+ self ._has_drop_connect = (
43
78
block_args .id_skip
44
- ) # whether to use skip connection and drop connect
79
+ and block_args .stride == 1
80
+ and block_args .input_filters == block_args .output_filters
81
+ )
82
+
83
+ # Pytorch's difference from tensorflow
84
+ bn_momentum = 1 - global_params .batch_norm_momentum
85
+ bn_eps = global_params .batch_norm_epsilon
45
86
46
87
# Expansion phase (Inverted Bottleneck)
47
- inp = self ._block_args .input_filters # number of input channels
48
- oup = (
49
- self ._block_args .input_filters * self ._block_args .expand_ratio
50
- ) # number of output channels
51
- if self ._block_args .expand_ratio != 1 :
88
+ input_channels = block_args .input_filters
89
+ expanded_channels = input_channels * block_args .expand_ratio
90
+
91
+ if self ._has_expansion :
52
92
Conv2d = get_same_padding_conv2d (image_size = image_size )
53
93
self ._expand_conv = Conv2d (
54
- in_channels = inp , out_channels = oup , kernel_size = 1 , bias = False
94
+ input_channels , expanded_channels , kernel_size = 1 , bias = False
55
95
)
56
96
self ._bn0 = nn .BatchNorm2d (
57
- num_features = oup , momentum = self ._bn_mom , eps = self ._bn_eps
97
+ expanded_channels ,
98
+ momentum = bn_momentum ,
99
+ eps = bn_eps ,
58
100
)
59
- # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
101
+ else :
102
+ # for torchscript compatibility
103
+ self ._expand_conv = nn .Identity ()
104
+ self ._bn0 = nn .Identity ()
60
105
61
106
# Depthwise convolution phase
62
- k = self . _block_args .kernel_size
63
- s = self . _block_args .stride
107
+ kernel_size = block_args .kernel_size
108
+ stride = block_args .stride
64
109
Conv2d = get_same_padding_conv2d (image_size = image_size )
65
110
self ._depthwise_conv = Conv2d (
66
- in_channels = oup ,
67
- out_channels = oup ,
68
- groups = oup , # groups makes it depthwise
69
- kernel_size = k ,
70
- stride = s ,
111
+ in_channels = expanded_channels ,
112
+ out_channels = expanded_channels ,
113
+ groups = expanded_channels , # groups makes it depthwise
114
+ kernel_size = kernel_size ,
115
+ stride = stride ,
71
116
bias = False ,
72
117
)
73
118
self ._bn1 = nn .BatchNorm2d (
74
- num_features = oup , momentum = self ._bn_mom , eps = self ._bn_eps
119
+ expanded_channels ,
120
+ momentum = bn_momentum ,
121
+ eps = bn_eps ,
75
122
)
76
- image_size = calculate_output_image_size (image_size , s )
123
+ image_size = calculate_output_image_size (image_size , stride )
77
124
78
125
# Squeeze and Excitation layer, if desired
79
- if self .has_se :
126
+ if self ._has_se :
127
+ squeezed_channels = int (input_channels * block_args .se_ratio )
128
+ squeezed_channels = max (1 , squeezed_channels )
80
129
Conv2d = get_same_padding_conv2d (image_size = (1 , 1 ))
81
- num_squeezed_channels = max (
82
- 1 , int (self ._block_args .input_filters * self ._block_args .se_ratio )
83
- )
84
130
self ._se_reduce = Conv2d (
85
- in_channels = oup , out_channels = num_squeezed_channels , kernel_size = 1
131
+ in_channels = expanded_channels ,
132
+ out_channels = squeezed_channels ,
133
+ kernel_size = 1 ,
86
134
)
87
135
self ._se_expand = Conv2d (
88
- in_channels = num_squeezed_channels , out_channels = oup , kernel_size = 1
136
+ in_channels = squeezed_channels ,
137
+ out_channels = expanded_channels ,
138
+ kernel_size = 1 ,
89
139
)
90
140
91
141
# Pointwise convolution phase
92
- final_oup = self . _block_args .output_filters
142
+ output_channels = block_args .output_filters
93
143
Conv2d = get_same_padding_conv2d (image_size = image_size )
94
144
self ._project_conv = Conv2d (
95
- in_channels = oup , out_channels = final_oup , kernel_size = 1 , bias = False
145
+ in_channels = expanded_channels ,
146
+ out_channels = output_channels ,
147
+ kernel_size = 1 ,
148
+ bias = False ,
96
149
)
97
150
self ._bn2 = nn .BatchNorm2d (
98
- num_features = final_oup , momentum = self ._bn_mom , eps = self ._bn_eps
151
+ num_features = output_channels ,
152
+ momentum = bn_momentum ,
153
+ eps = bn_eps ,
99
154
)
100
155
self ._swish = nn .SiLU ()
101
156
102
- def forward (self , inputs , drop_connect_rate = None ):
157
+ def forward (self , inputs : torch . Tensor , drop_connect_rate : Optional [ float ] = None ):
103
158
"""MBConvBlock's forward function.
104
159
105
160
Args:
@@ -112,7 +167,7 @@ def forward(self, inputs, drop_connect_rate=None):
112
167
113
168
# Expansion and Depthwise Convolution
114
169
x = inputs
115
- if self ._block_args . expand_ratio != 1 :
170
+ if self ._has_expansion :
116
171
x = self ._expand_conv (inputs )
117
172
x = self ._bn0 (x )
118
173
x = self ._swish (x )
@@ -122,7 +177,7 @@ def forward(self, inputs, drop_connect_rate=None):
122
177
x = self ._swish (x )
123
178
124
179
# Squeeze and Excitation
125
- if self .has_se :
180
+ if self ._has_se :
126
181
x_squeezed = F .adaptive_avg_pool2d (x , 1 )
127
182
x_squeezed = self ._se_reduce (x_squeezed )
128
183
x_squeezed = self ._swish (x_squeezed )
@@ -134,17 +189,9 @@ def forward(self, inputs, drop_connect_rate=None):
134
189
x = self ._bn2 (x )
135
190
136
191
# Skip connection and drop connect
137
- input_filters , output_filters = (
138
- self ._block_args .input_filters ,
139
- self ._block_args .output_filters ,
140
- )
141
- if (
142
- self .id_skip
143
- and self ._block_args .stride == 1
144
- and input_filters == output_filters
145
- ):
192
+ if self ._has_drop_connect :
146
193
# The combination of skip connection and drop connect brings about stochastic depth.
147
- if drop_connect_rate :
194
+ if drop_connect_rate is not None and drop_connect_rate > 0 :
148
195
x = drop_connect (x , p = drop_connect_rate , training = self .training )
149
196
x = x + inputs # skip connection
150
197
return x
@@ -169,10 +216,14 @@ class EfficientNet(nn.Module):
169
216
>>> outputs = model(inputs)
170
217
"""
171
218
172
- def __init__ (self , blocks_args = None , global_params = None ):
219
+ def __init__ (self , blocks_args : List [ BlockArgs ] , global_params : GlobalParams ):
173
220
super ().__init__ ()
174
- assert isinstance (blocks_args , list ), "blocks_args should be a list"
175
- assert len (blocks_args ) > 0 , "block args must be greater than 0"
221
+
222
+ if not isinstance (blocks_args , list ):
223
+ raise ValueError ("blocks_args should be a list" )
224
+ if len (blocks_args ) == 0 :
225
+ raise ValueError ("block args must be greater than 0" )
226
+
176
227
self ._global_params = global_params
177
228
self ._blocks_args = blocks_args
178
229
@@ -186,20 +237,16 @@ def __init__(self, blocks_args=None, global_params=None):
186
237
187
238
# Stem
188
239
in_channels = 3 # rgb
189
- out_channels = round_filters (
190
- 32 , self ._global_params
191
- ) # number of output channels
240
+ out_channels = round_filters (32 , self ._global_params )
192
241
self ._conv_stem = Conv2d (
193
242
in_channels , out_channels , kernel_size = 3 , stride = 2 , bias = False
194
243
)
195
- self ._bn0 = nn .BatchNorm2d (
196
- num_features = out_channels , momentum = bn_mom , eps = bn_eps
197
- )
244
+ self ._bn0 = nn .BatchNorm2d (out_channels , momentum = bn_mom , eps = bn_eps )
198
245
image_size = calculate_output_image_size (image_size , 2 )
199
246
200
247
# Build blocks
201
248
self ._blocks = nn .ModuleList ([])
202
- for block_args in self . _blocks_args :
249
+ for block_args in blocks_args :
203
250
# Update block input and output filters based on depth multiplier.
204
251
block_args = block_args ._replace (
205
252
input_filters = round_filters (
@@ -243,57 +290,8 @@ def __init__(self, blocks_args=None, global_params=None):
243
290
244
291
self ._swish = nn .SiLU ()
245
292
246
- def extract_endpoints (self , inputs ):
247
- """Use convolution layer to extract features
248
- from reduction levels i in [1, 2, 3, 4, 5].
249
-
250
- Args:
251
- inputs (tensor): Input tensor.
252
-
253
- Returns:
254
- Dictionary of last intermediate features
255
- with reduction levels i in [1, 2, 3, 4, 5].
256
- Example:
257
- >>> import torch
258
- >>> from efficientnet.model import EfficientNet
259
- >>> inputs = torch.rand(1, 3, 224, 224)
260
- >>> model = EfficientNet.from_pretrained('efficientnet-b0')
261
- >>> endpoints = model.extract_endpoints(inputs)
262
- >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
263
- >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
264
- >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
265
- >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
266
- >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
267
- >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
268
- """
269
- endpoints = dict ()
270
-
271
- # Stem
272
- x = self ._swish (self ._bn0 (self ._conv_stem (inputs )))
273
- prev_x = x
274
-
275
- # Blocks
276
- for idx , block in enumerate (self ._blocks ):
277
- drop_connect_rate = self ._global_params .drop_connect_rate
278
- if drop_connect_rate :
279
- drop_connect_rate *= float (idx ) / len (
280
- self ._blocks
281
- ) # scale drop connect_rate
282
- x = block (x , drop_connect_rate = drop_connect_rate )
283
- if prev_x .size (2 ) > x .size (2 ):
284
- endpoints ["reduction_{}" .format (len (endpoints ) + 1 )] = prev_x
285
- elif idx == len (self ._blocks ) - 1 :
286
- endpoints ["reduction_{}" .format (len (endpoints ) + 1 )] = x
287
- prev_x = x
288
-
289
- # Head
290
- x = self ._swish (self ._bn1 (self ._conv_head (x )))
291
- endpoints ["reduction_{}" .format (len (endpoints ) + 1 )] = x
292
-
293
- return endpoints
294
-
295
293
def extract_features (self , inputs ):
296
- """use convolution layer to extract feature .
294
+ """Use convolution layer to extract feature.
297
295
298
296
Args:
299
297
inputs (tensor): Input tensor.
@@ -309,9 +307,8 @@ def extract_features(self, inputs):
309
307
for idx , block in enumerate (self ._blocks ):
310
308
drop_connect_rate = self ._global_params .drop_connect_rate
311
309
if drop_connect_rate :
312
- drop_connect_rate *= float (idx ) / len (
313
- self ._blocks
314
- ) # scale drop connect_rate
310
+ # scale drop connect_rate
311
+ drop_connect_rate *= float (idx ) / len (self ._blocks )
315
312
x = block (x , drop_connect_rate = drop_connect_rate )
316
313
317
314
# Head
@@ -321,7 +318,7 @@ def extract_features(self, inputs):
321
318
322
319
def forward (self , inputs ):
323
320
"""EfficientNet's forward function.
324
- Calls extract_features to extract features, applies final linear layer, and returns logits.
321
+ Calls extract_features to extract features, applies final linear layer, and returns logits.
325
322
326
323
Args:
327
324
inputs (tensor): Input tensor.
@@ -331,6 +328,7 @@ def forward(self, inputs):
331
328
"""
332
329
# Convolution layers
333
330
x = self .extract_features (inputs )
331
+
334
332
# Pooling and final linear layer
335
333
x = self ._avg_pooling (x )
336
334
if self ._global_params .include_top :
@@ -358,43 +356,6 @@ def forward(self, inputs):
358
356
# It's an additional function, not used in EfficientNet,
359
357
# but can be used in other model (such as EfficientDet).
360
358
361
- # Parameters for the entire model (stem, all blocks, and head)
362
- GlobalParams = collections .namedtuple (
363
- "GlobalParams" ,
364
- [
365
- "width_coefficient" ,
366
- "depth_coefficient" ,
367
- "image_size" ,
368
- "dropout_rate" ,
369
- "num_classes" ,
370
- "batch_norm_momentum" ,
371
- "batch_norm_epsilon" ,
372
- "drop_connect_rate" ,
373
- "depth_divisor" ,
374
- "min_depth" ,
375
- "include_top" ,
376
- ],
377
- )
378
-
379
- # Parameters for an individual model block
380
- BlockArgs = collections .namedtuple (
381
- "BlockArgs" ,
382
- [
383
- "num_repeat" ,
384
- "kernel_size" ,
385
- "stride" ,
386
- "expand_ratio" ,
387
- "input_filters" ,
388
- "output_filters" ,
389
- "se_ratio" ,
390
- "id_skip" ,
391
- ],
392
- )
393
-
394
- # Set GlobalParams and BlockArgs's defaults
395
- GlobalParams .__new__ .__defaults__ = (None ,) * len (GlobalParams ._fields )
396
- BlockArgs .__new__ .__defaults__ = (None ,) * len (BlockArgs ._fields )
397
-
398
359
399
360
def round_filters (filters , global_params ):
400
361
"""Calculate and round number of filters based on width multiplier.
@@ -442,7 +403,7 @@ def round_repeats(repeats, global_params):
442
403
return int (math .ceil (multiplier * repeats ))
443
404
444
405
445
- def drop_connect (inputs , p , training ) :
406
+ def drop_connect (inputs : torch . Tensor , p : float , training : bool ) -> torch . Tensor :
446
407
"""Drop connect.
447
408
448
409
Args:
0 commit comments