1313import math
1414import collections
1515from functools import partial
16+ from typing import List , Optional
17+
18+ # Parameters for the entire model (stem, all blocks, and head)
19+ GlobalParams = collections .namedtuple (
20+ "GlobalParams" ,
21+ [
22+ "width_coefficient" ,
23+ "depth_coefficient" ,
24+ "image_size" ,
25+ "dropout_rate" ,
26+ "num_classes" ,
27+ "batch_norm_momentum" ,
28+ "batch_norm_epsilon" ,
29+ "drop_connect_rate" ,
30+ "depth_divisor" ,
31+ "min_depth" ,
32+ "include_top" ,
33+ ],
34+ )
35+
36+ # Parameters for an individual model block
37+ BlockArgs = collections .namedtuple (
38+ "BlockArgs" ,
39+ [
40+ "num_repeat" ,
41+ "kernel_size" ,
42+ "stride" ,
43+ "expand_ratio" ,
44+ "input_filters" ,
45+ "output_filters" ,
46+ "se_ratio" ,
47+ "id_skip" ,
48+ ],
49+ )
50+
51+ # Set GlobalParams and BlockArgs's defaults
52+ GlobalParams .__new__ .__defaults__ = (None ,) * len (GlobalParams ._fields )
53+ BlockArgs .__new__ .__defaults__ = (None ,) * len (BlockArgs ._fields )
1654
1755
1856class MBConvBlock (nn .Module ):
@@ -29,77 +67,94 @@ class MBConvBlock(nn.Module):
2967 [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
3068 """
3169
32- def __init__ (self , block_args , global_params , image_size = None ):
70+ def __init__ (
71+ self , block_args : BlockArgs , global_params : GlobalParams , image_size = None
72+ ):
3373 super ().__init__ ()
34- self ._block_args = block_args
35- self ._bn_mom = (
36- 1 - global_params .batch_norm_momentum
37- ) # pytorch's difference from tensorflow
38- self ._bn_eps = global_params .batch_norm_epsilon
39- self .has_se = (self ._block_args .se_ratio is not None ) and (
40- 0 < self ._block_args .se_ratio <= 1
41- )
42- self .id_skip = (
74+
75+ self ._has_expansion = block_args .expand_ratio != 1
76+ self ._has_se = block_args .se_ratio is not None and 0 < block_args .se_ratio <= 1
77+ self ._has_drop_connect = (
4378 block_args .id_skip
44- ) # whether to use skip connection and drop connect
79+ and block_args .stride == 1
80+ and block_args .input_filters == block_args .output_filters
81+ )
82+
83+ # Pytorch's difference from tensorflow
84+ bn_momentum = 1 - global_params .batch_norm_momentum
85+ bn_eps = global_params .batch_norm_epsilon
4586
4687 # Expansion phase (Inverted Bottleneck)
47- inp = self ._block_args .input_filters # number of input channels
48- oup = (
49- self ._block_args .input_filters * self ._block_args .expand_ratio
50- ) # number of output channels
51- if self ._block_args .expand_ratio != 1 :
88+ input_channels = block_args .input_filters
89+ expanded_channels = input_channels * block_args .expand_ratio
90+
91+ if self ._has_expansion :
5292 Conv2d = get_same_padding_conv2d (image_size = image_size )
5393 self ._expand_conv = Conv2d (
54- in_channels = inp , out_channels = oup , kernel_size = 1 , bias = False
94+ input_channels , expanded_channels , kernel_size = 1 , bias = False
5595 )
5696 self ._bn0 = nn .BatchNorm2d (
57- num_features = oup , momentum = self ._bn_mom , eps = self ._bn_eps
97+ expanded_channels ,
98+ momentum = bn_momentum ,
99+ eps = bn_eps ,
58100 )
59- # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
101+ else :
102+ # for torchscript compatibility
103+ self ._expand_conv = nn .Identity ()
104+ self ._bn0 = nn .Identity ()
60105
61106 # Depthwise convolution phase
62- k = self . _block_args .kernel_size
63- s = self . _block_args .stride
107+ kernel_size = block_args .kernel_size
108+ stride = block_args .stride
64109 Conv2d = get_same_padding_conv2d (image_size = image_size )
65110 self ._depthwise_conv = Conv2d (
66- in_channels = oup ,
67- out_channels = oup ,
68- groups = oup , # groups makes it depthwise
69- kernel_size = k ,
70- stride = s ,
111+ in_channels = expanded_channels ,
112+ out_channels = expanded_channels ,
113+ groups = expanded_channels , # groups makes it depthwise
114+ kernel_size = kernel_size ,
115+ stride = stride ,
71116 bias = False ,
72117 )
73118 self ._bn1 = nn .BatchNorm2d (
74- num_features = oup , momentum = self ._bn_mom , eps = self ._bn_eps
119+ expanded_channels ,
120+ momentum = bn_momentum ,
121+ eps = bn_eps ,
75122 )
76- image_size = calculate_output_image_size (image_size , s )
123+ image_size = calculate_output_image_size (image_size , stride )
77124
78125 # Squeeze and Excitation layer, if desired
79- if self .has_se :
126+ if self ._has_se :
127+ squeezed_channels = int (input_channels * block_args .se_ratio )
128+ squeezed_channels = max (1 , squeezed_channels )
80129 Conv2d = get_same_padding_conv2d (image_size = (1 , 1 ))
81- num_squeezed_channels = max (
82- 1 , int (self ._block_args .input_filters * self ._block_args .se_ratio )
83- )
84130 self ._se_reduce = Conv2d (
85- in_channels = oup , out_channels = num_squeezed_channels , kernel_size = 1
131+ in_channels = expanded_channels ,
132+ out_channels = squeezed_channels ,
133+ kernel_size = 1 ,
86134 )
87135 self ._se_expand = Conv2d (
88- in_channels = num_squeezed_channels , out_channels = oup , kernel_size = 1
136+ in_channels = squeezed_channels ,
137+ out_channels = expanded_channels ,
138+ kernel_size = 1 ,
89139 )
90140
91141 # Pointwise convolution phase
92- final_oup = self . _block_args .output_filters
142+ output_channels = block_args .output_filters
93143 Conv2d = get_same_padding_conv2d (image_size = image_size )
94144 self ._project_conv = Conv2d (
95- in_channels = oup , out_channels = final_oup , kernel_size = 1 , bias = False
145+ in_channels = expanded_channels ,
146+ out_channels = output_channels ,
147+ kernel_size = 1 ,
148+ bias = False ,
96149 )
97150 self ._bn2 = nn .BatchNorm2d (
98- num_features = final_oup , momentum = self ._bn_mom , eps = self ._bn_eps
151+ num_features = output_channels ,
152+ momentum = bn_momentum ,
153+ eps = bn_eps ,
99154 )
100155 self ._swish = nn .SiLU ()
101156
102- def forward (self , inputs , drop_connect_rate = None ):
157+ def forward (self , inputs : torch . Tensor , drop_connect_rate : Optional [ float ] = None ):
103158 """MBConvBlock's forward function.
104159
105160 Args:
@@ -112,7 +167,7 @@ def forward(self, inputs, drop_connect_rate=None):
112167
113168 # Expansion and Depthwise Convolution
114169 x = inputs
115- if self ._block_args . expand_ratio != 1 :
170+ if self ._has_expansion :
116171 x = self ._expand_conv (inputs )
117172 x = self ._bn0 (x )
118173 x = self ._swish (x )
@@ -122,7 +177,7 @@ def forward(self, inputs, drop_connect_rate=None):
122177 x = self ._swish (x )
123178
124179 # Squeeze and Excitation
125- if self .has_se :
180+ if self ._has_se :
126181 x_squeezed = F .adaptive_avg_pool2d (x , 1 )
127182 x_squeezed = self ._se_reduce (x_squeezed )
128183 x_squeezed = self ._swish (x_squeezed )
@@ -134,17 +189,9 @@ def forward(self, inputs, drop_connect_rate=None):
134189 x = self ._bn2 (x )
135190
136191 # Skip connection and drop connect
137- input_filters , output_filters = (
138- self ._block_args .input_filters ,
139- self ._block_args .output_filters ,
140- )
141- if (
142- self .id_skip
143- and self ._block_args .stride == 1
144- and input_filters == output_filters
145- ):
192+ if self ._has_drop_connect :
146193 # The combination of skip connection and drop connect brings about stochastic depth.
147- if drop_connect_rate :
194+ if drop_connect_rate is not None and drop_connect_rate > 0 :
148195 x = drop_connect (x , p = drop_connect_rate , training = self .training )
149196 x = x + inputs # skip connection
150197 return x
@@ -169,10 +216,14 @@ class EfficientNet(nn.Module):
169216 >>> outputs = model(inputs)
170217 """
171218
172- def __init__ (self , blocks_args = None , global_params = None ):
219+ def __init__ (self , blocks_args : List [ BlockArgs ] , global_params : GlobalParams ):
173220 super ().__init__ ()
174- assert isinstance (blocks_args , list ), "blocks_args should be a list"
175- assert len (blocks_args ) > 0 , "block args must be greater than 0"
221+
222+ if not isinstance (blocks_args , list ):
223+ raise ValueError ("blocks_args should be a list" )
224+ if len (blocks_args ) == 0 :
225+ raise ValueError ("block args must be greater than 0" )
226+
176227 self ._global_params = global_params
177228 self ._blocks_args = blocks_args
178229
@@ -186,20 +237,16 @@ def __init__(self, blocks_args=None, global_params=None):
186237
187238 # Stem
188239 in_channels = 3 # rgb
189- out_channels = round_filters (
190- 32 , self ._global_params
191- ) # number of output channels
240+ out_channels = round_filters (32 , self ._global_params )
192241 self ._conv_stem = Conv2d (
193242 in_channels , out_channels , kernel_size = 3 , stride = 2 , bias = False
194243 )
195- self ._bn0 = nn .BatchNorm2d (
196- num_features = out_channels , momentum = bn_mom , eps = bn_eps
197- )
244+ self ._bn0 = nn .BatchNorm2d (out_channels , momentum = bn_mom , eps = bn_eps )
198245 image_size = calculate_output_image_size (image_size , 2 )
199246
200247 # Build blocks
201248 self ._blocks = nn .ModuleList ([])
202- for block_args in self . _blocks_args :
249+ for block_args in blocks_args :
203250 # Update block input and output filters based on depth multiplier.
204251 block_args = block_args ._replace (
205252 input_filters = round_filters (
@@ -243,57 +290,8 @@ def __init__(self, blocks_args=None, global_params=None):
243290
244291 self ._swish = nn .SiLU ()
245292
246- def extract_endpoints (self , inputs ):
247- """Use convolution layer to extract features
248- from reduction levels i in [1, 2, 3, 4, 5].
249-
250- Args:
251- inputs (tensor): Input tensor.
252-
253- Returns:
254- Dictionary of last intermediate features
255- with reduction levels i in [1, 2, 3, 4, 5].
256- Example:
257- >>> import torch
258- >>> from efficientnet.model import EfficientNet
259- >>> inputs = torch.rand(1, 3, 224, 224)
260- >>> model = EfficientNet.from_pretrained('efficientnet-b0')
261- >>> endpoints = model.extract_endpoints(inputs)
262- >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
263- >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
264- >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
265- >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
266- >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
267- >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
268- """
269- endpoints = dict ()
270-
271- # Stem
272- x = self ._swish (self ._bn0 (self ._conv_stem (inputs )))
273- prev_x = x
274-
275- # Blocks
276- for idx , block in enumerate (self ._blocks ):
277- drop_connect_rate = self ._global_params .drop_connect_rate
278- if drop_connect_rate :
279- drop_connect_rate *= float (idx ) / len (
280- self ._blocks
281- ) # scale drop connect_rate
282- x = block (x , drop_connect_rate = drop_connect_rate )
283- if prev_x .size (2 ) > x .size (2 ):
284- endpoints ["reduction_{}" .format (len (endpoints ) + 1 )] = prev_x
285- elif idx == len (self ._blocks ) - 1 :
286- endpoints ["reduction_{}" .format (len (endpoints ) + 1 )] = x
287- prev_x = x
288-
289- # Head
290- x = self ._swish (self ._bn1 (self ._conv_head (x )))
291- endpoints ["reduction_{}" .format (len (endpoints ) + 1 )] = x
292-
293- return endpoints
294-
295293 def extract_features (self , inputs ):
296- """use convolution layer to extract feature .
294+ """Use convolution layer to extract feature.
297295
298296 Args:
299297 inputs (tensor): Input tensor.
@@ -309,9 +307,8 @@ def extract_features(self, inputs):
309307 for idx , block in enumerate (self ._blocks ):
310308 drop_connect_rate = self ._global_params .drop_connect_rate
311309 if drop_connect_rate :
312- drop_connect_rate *= float (idx ) / len (
313- self ._blocks
314- ) # scale drop connect_rate
310+ # scale drop connect_rate
311+ drop_connect_rate *= float (idx ) / len (self ._blocks )
315312 x = block (x , drop_connect_rate = drop_connect_rate )
316313
317314 # Head
@@ -321,7 +318,7 @@ def extract_features(self, inputs):
321318
322319 def forward (self , inputs ):
323320 """EfficientNet's forward function.
324- Calls extract_features to extract features, applies final linear layer, and returns logits.
321+ Calls extract_features to extract features, applies final linear layer, and returns logits.
325322
326323 Args:
327324 inputs (tensor): Input tensor.
@@ -331,6 +328,7 @@ def forward(self, inputs):
331328 """
332329 # Convolution layers
333330 x = self .extract_features (inputs )
331+
334332 # Pooling and final linear layer
335333 x = self ._avg_pooling (x )
336334 if self ._global_params .include_top :
@@ -358,43 +356,6 @@ def forward(self, inputs):
358356# It's an additional function, not used in EfficientNet,
359357# but can be used in other model (such as EfficientDet).
360358
361- # Parameters for the entire model (stem, all blocks, and head)
362- GlobalParams = collections .namedtuple (
363- "GlobalParams" ,
364- [
365- "width_coefficient" ,
366- "depth_coefficient" ,
367- "image_size" ,
368- "dropout_rate" ,
369- "num_classes" ,
370- "batch_norm_momentum" ,
371- "batch_norm_epsilon" ,
372- "drop_connect_rate" ,
373- "depth_divisor" ,
374- "min_depth" ,
375- "include_top" ,
376- ],
377- )
378-
379- # Parameters for an individual model block
380- BlockArgs = collections .namedtuple (
381- "BlockArgs" ,
382- [
383- "num_repeat" ,
384- "kernel_size" ,
385- "stride" ,
386- "expand_ratio" ,
387- "input_filters" ,
388- "output_filters" ,
389- "se_ratio" ,
390- "id_skip" ,
391- ],
392- )
393-
394- # Set GlobalParams and BlockArgs's defaults
395- GlobalParams .__new__ .__defaults__ = (None ,) * len (GlobalParams ._fields )
396- BlockArgs .__new__ .__defaults__ = (None ,) * len (BlockArgs ._fields )
397-
398359
399360def round_filters (filters , global_params ):
400361 """Calculate and round number of filters based on width multiplier.
@@ -442,7 +403,7 @@ def round_repeats(repeats, global_params):
442403 return int (math .ceil (multiplier * repeats ))
443404
444405
445- def drop_connect (inputs , p , training ) :
406+ def drop_connect (inputs : torch . Tensor , p : float , training : bool ) -> torch . Tensor :
446407 """Drop connect.
447408
448409 Args:
0 commit comments