@@ -113,7 +113,7 @@ def __init__(self, block_args, global_params, image_size=None):
113
113
self ._bn2 = nn .BatchNorm2d (
114
114
num_features = final_oup , momentum = self ._bn_mom , eps = self ._bn_eps
115
115
)
116
- self ._swish = MemoryEfficientSwish ()
116
+ self ._swish = nn . SiLU ()
117
117
118
118
def forward (self , inputs , drop_connect_rate = None ):
119
119
"""MBConvBlock's forward function.
@@ -165,14 +165,6 @@ def forward(self, inputs, drop_connect_rate=None):
165
165
x = x + inputs # skip connection
166
166
return x
167
167
168
- def set_swish (self , memory_efficient = True ):
169
- """Sets swish function as memory efficient (for training) or standard (for export).
170
-
171
- Args:
172
- memory_efficient (bool): Whether to use memory-efficient version of swish.
173
- """
174
- self ._swish = MemoryEfficientSwish () if memory_efficient else nn .SiLU ()
175
-
176
168
177
169
class EfficientNet (nn .Module ):
178
170
"""EfficientNet model.
@@ -265,18 +257,7 @@ def __init__(self, blocks_args=None, global_params=None):
265
257
self ._dropout = nn .Dropout (self ._global_params .dropout_rate )
266
258
self ._fc = nn .Linear (out_channels , self ._global_params .num_classes )
267
259
268
- # set activation to memory efficient swish by default
269
- self ._swish = MemoryEfficientSwish ()
270
-
271
- def set_swish (self , memory_efficient = True ):
272
- """Sets swish function as memory efficient (for training) or standard (for export).
273
-
274
- Args:
275
- memory_efficient (bool): Whether to use memory-efficient version of swish.
276
- """
277
- self ._swish = MemoryEfficientSwish () if memory_efficient else nn .SiLU ()
278
- for block in self ._blocks :
279
- block .set_swish (memory_efficient )
260
+ self ._swish = nn .SiLU ()
280
261
281
262
def extract_endpoints (self , inputs ):
282
263
"""Use convolution layer to extract features
@@ -380,7 +361,6 @@ def forward(self, inputs):
380
361
################################################################################
381
362
382
363
# GlobalParams and BlockArgs: Two namedtuples
383
- # Swish and MemoryEfficientSwish: Two implementations of the method
384
364
# round_filters and round_repeats:
385
365
# Functions to calculate params for scaling model width and depth ! ! !
386
366
# get_width_and_height_from_size and calculate_output_image_size
@@ -432,26 +412,6 @@ def forward(self, inputs):
432
412
BlockArgs .__new__ .__defaults__ = (None ,) * len (BlockArgs ._fields )
433
413
434
414
435
- # A memory-efficient implementation of Swish function
436
- class SwishImplementation (torch .autograd .Function ):
437
- @staticmethod
438
- def forward (ctx , i ):
439
- result = i * torch .sigmoid (i )
440
- ctx .save_for_backward (i )
441
- return result
442
-
443
- @staticmethod
444
- def backward (ctx , grad_output ):
445
- i = ctx .saved_tensors [0 ]
446
- sigmoid_i = torch .sigmoid (i )
447
- return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i )))
448
-
449
-
450
- class MemoryEfficientSwish (nn .Module ):
451
- def forward (self , x ):
452
- return SwishImplementation .apply (x )
453
-
454
-
455
415
def round_filters (filters , global_params ):
456
416
"""Calculate and round number of filters based on width multiplier.
457
417
Use width_coefficient, depth_divisor and min_depth of global_params.
0 commit comments