Skip to content

Commit 002ae3e

Browse files
committed
Standard Swish implementation
1 parent 29f0bbf commit 002ae3e

File tree

1 file changed

+2
-42
lines changed

1 file changed

+2
-42
lines changed

segmentation_models_pytorch/encoders/_efficientnet.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(self, block_args, global_params, image_size=None):
113113
self._bn2 = nn.BatchNorm2d(
114114
num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
115115
)
116-
self._swish = MemoryEfficientSwish()
116+
self._swish = nn.SiLU()
117117

118118
def forward(self, inputs, drop_connect_rate=None):
119119
"""MBConvBlock's forward function.
@@ -165,14 +165,6 @@ def forward(self, inputs, drop_connect_rate=None):
165165
x = x + inputs # skip connection
166166
return x
167167

168-
def set_swish(self, memory_efficient=True):
169-
"""Sets swish function as memory efficient (for training) or standard (for export).
170-
171-
Args:
172-
memory_efficient (bool): Whether to use memory-efficient version of swish.
173-
"""
174-
self._swish = MemoryEfficientSwish() if memory_efficient else nn.SiLU()
175-
176168

177169
class EfficientNet(nn.Module):
178170
"""EfficientNet model.
@@ -265,18 +257,7 @@ def __init__(self, blocks_args=None, global_params=None):
265257
self._dropout = nn.Dropout(self._global_params.dropout_rate)
266258
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
267259

268-
# set activation to memory efficient swish by default
269-
self._swish = MemoryEfficientSwish()
270-
271-
def set_swish(self, memory_efficient=True):
272-
"""Sets swish function as memory efficient (for training) or standard (for export).
273-
274-
Args:
275-
memory_efficient (bool): Whether to use memory-efficient version of swish.
276-
"""
277-
self._swish = MemoryEfficientSwish() if memory_efficient else nn.SiLU()
278-
for block in self._blocks:
279-
block.set_swish(memory_efficient)
260+
self._swish = nn.SiLU()
280261

281262
def extract_endpoints(self, inputs):
282263
"""Use convolution layer to extract features
@@ -380,7 +361,6 @@ def forward(self, inputs):
380361
################################################################################
381362

382363
# GlobalParams and BlockArgs: Two namedtuples
383-
# Swish and MemoryEfficientSwish: Two implementations of the method
384364
# round_filters and round_repeats:
385365
# Functions to calculate params for scaling model width and depth ! ! !
386366
# get_width_and_height_from_size and calculate_output_image_size
@@ -432,26 +412,6 @@ def forward(self, inputs):
432412
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
433413

434414

435-
# A memory-efficient implementation of Swish function
436-
class SwishImplementation(torch.autograd.Function):
437-
@staticmethod
438-
def forward(ctx, i):
439-
result = i * torch.sigmoid(i)
440-
ctx.save_for_backward(i)
441-
return result
442-
443-
@staticmethod
444-
def backward(ctx, grad_output):
445-
i = ctx.saved_tensors[0]
446-
sigmoid_i = torch.sigmoid(i)
447-
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
448-
449-
450-
class MemoryEfficientSwish(nn.Module):
451-
def forward(self, x):
452-
return SwishImplementation.apply(x)
453-
454-
455415
def round_filters(filters, global_params):
456416
"""Calculate and round number of filters based on width multiplier.
457417
Use width_coefficient, depth_divisor and min_depth of global_params.

0 commit comments

Comments
 (0)