@@ -113,7 +113,7 @@ def __init__(self, block_args, global_params, image_size=None):
113113 self ._bn2 = nn .BatchNorm2d (
114114 num_features = final_oup , momentum = self ._bn_mom , eps = self ._bn_eps
115115 )
116- self ._swish = MemoryEfficientSwish ()
116+ self ._swish = nn . SiLU ()
117117
118118 def forward (self , inputs , drop_connect_rate = None ):
119119 """MBConvBlock's forward function.
@@ -165,14 +165,6 @@ def forward(self, inputs, drop_connect_rate=None):
165165 x = x + inputs # skip connection
166166 return x
167167
168- def set_swish (self , memory_efficient = True ):
169- """Sets swish function as memory efficient (for training) or standard (for export).
170-
171- Args:
172- memory_efficient (bool): Whether to use memory-efficient version of swish.
173- """
174- self ._swish = MemoryEfficientSwish () if memory_efficient else nn .SiLU ()
175-
176168
177169class EfficientNet (nn .Module ):
178170 """EfficientNet model.
@@ -265,18 +257,7 @@ def __init__(self, blocks_args=None, global_params=None):
265257 self ._dropout = nn .Dropout (self ._global_params .dropout_rate )
266258 self ._fc = nn .Linear (out_channels , self ._global_params .num_classes )
267259
268- # set activation to memory efficient swish by default
269- self ._swish = MemoryEfficientSwish ()
270-
271- def set_swish (self , memory_efficient = True ):
272- """Sets swish function as memory efficient (for training) or standard (for export).
273-
274- Args:
275- memory_efficient (bool): Whether to use memory-efficient version of swish.
276- """
277- self ._swish = MemoryEfficientSwish () if memory_efficient else nn .SiLU ()
278- for block in self ._blocks :
279- block .set_swish (memory_efficient )
260+ self ._swish = nn .SiLU ()
280261
281262 def extract_endpoints (self , inputs ):
282263 """Use convolution layer to extract features
@@ -380,7 +361,6 @@ def forward(self, inputs):
380361################################################################################
381362
382363# GlobalParams and BlockArgs: Two namedtuples
383- # Swish and MemoryEfficientSwish: Two implementations of the method
384364# round_filters and round_repeats:
385365# Functions to calculate params for scaling model width and depth ! ! !
386366# get_width_and_height_from_size and calculate_output_image_size
@@ -432,26 +412,6 @@ def forward(self, inputs):
432412BlockArgs .__new__ .__defaults__ = (None ,) * len (BlockArgs ._fields )
433413
434414
435- # A memory-efficient implementation of Swish function
436- class SwishImplementation (torch .autograd .Function ):
437- @staticmethod
438- def forward (ctx , i ):
439- result = i * torch .sigmoid (i )
440- ctx .save_for_backward (i )
441- return result
442-
443- @staticmethod
444- def backward (ctx , grad_output ):
445- i = ctx .saved_tensors [0 ]
446- sigmoid_i = torch .sigmoid (i )
447- return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i )))
448-
449-
450- class MemoryEfficientSwish (nn .Module ):
451- def forward (self , x ):
452- return SwishImplementation .apply (x )
453-
454-
455415def round_filters (filters , global_params ):
456416 """Calculate and round number of filters based on width multiplier.
457417 Use width_coefficient, depth_divisor and min_depth of global_params.
0 commit comments