Skip to content

Commit 29f0bbf

Browse files
committed
Remove classmethods, PyTorch compatibility layer
1 parent 3635e19 commit 29f0bbf

File tree

1 file changed

+2
-133
lines changed

1 file changed

+2
-133
lines changed

segmentation_models_pytorch/encoders/_efficientnet.py

Lines changed: 2 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,11 @@ def set_swish(self, memory_efficient=True):
171171
Args:
172172
memory_efficient (bool): Whether to use memory-efficient version of swish.
173173
"""
174-
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
174+
self._swish = MemoryEfficientSwish() if memory_efficient else nn.SiLU()
175175

176176

177177
class EfficientNet(nn.Module):
178178
"""EfficientNet model.
179-
Most easily loaded with the .from_name or .from_pretrained methods.
180179
181180
Args:
182181
blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
@@ -275,7 +274,7 @@ def set_swish(self, memory_efficient=True):
275274
Args:
276275
memory_efficient (bool): Whether to use memory-efficient version of swish.
277276
"""
278-
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
277+
self._swish = MemoryEfficientSwish() if memory_efficient else nn.SiLU()
279278
for block in self._blocks:
280279
block.set_swish(memory_efficient)
281280

@@ -375,127 +374,6 @@ def forward(self, inputs):
375374
x = self._fc(x)
376375
return x
377376

378-
@classmethod
379-
def from_name(cls, model_name, in_channels=3, **override_params):
380-
"""Create an efficientnet model according to name.
381-
382-
Args:
383-
model_name (str): Name for efficientnet.
384-
in_channels (int): Input data's channel number.
385-
override_params (other key word params):
386-
Params to override model's global_params.
387-
Optional key:
388-
'width_coefficient', 'depth_coefficient',
389-
'image_size', 'dropout_rate',
390-
'num_classes', 'batch_norm_momentum',
391-
'batch_norm_epsilon', 'drop_connect_rate',
392-
'depth_divisor', 'min_depth'
393-
394-
Returns:
395-
An efficientnet model.
396-
"""
397-
cls._check_model_name_is_valid(model_name)
398-
blocks_args, global_params = get_model_params(model_name, override_params)
399-
model = cls(blocks_args, global_params)
400-
model._change_in_channels(in_channels)
401-
return model
402-
403-
@classmethod
404-
def from_pretrained(
405-
cls,
406-
model_name,
407-
weights_path=None,
408-
advprop=False,
409-
in_channels=3,
410-
num_classes=1000,
411-
**override_params,
412-
):
413-
"""Create an efficientnet model according to name.
414-
415-
Args:
416-
model_name (str): Name for efficientnet.
417-
weights_path (None or str):
418-
str: path to pretrained weights file on the local disk.
419-
None: use pretrained weights downloaded from the Internet.
420-
advprop (bool):
421-
Whether to load pretrained weights
422-
trained with advprop (valid when weights_path is None).
423-
in_channels (int): Input data's channel number.
424-
num_classes (int):
425-
Number of categories for classification.
426-
It controls the output size for final linear layer.
427-
override_params (other key word params):
428-
Params to override model's global_params.
429-
Optional key:
430-
'width_coefficient', 'depth_coefficient',
431-
'image_size', 'dropout_rate',
432-
'batch_norm_momentum',
433-
'batch_norm_epsilon', 'drop_connect_rate',
434-
'depth_divisor', 'min_depth'
435-
436-
Returns:
437-
A pretrained efficientnet model.
438-
"""
439-
model = cls.from_name(model_name, num_classes=num_classes, **override_params)
440-
load_pretrained_weights(
441-
model,
442-
model_name,
443-
weights_path=weights_path,
444-
load_fc=(num_classes == 1000),
445-
advprop=advprop,
446-
)
447-
model._change_in_channels(in_channels)
448-
return model
449-
450-
@classmethod
451-
def get_image_size(cls, model_name):
452-
"""Get the input image size for a given efficientnet model.
453-
454-
Args:
455-
model_name (str): Name for efficientnet.
456-
457-
Returns:
458-
Input image size (resolution).
459-
"""
460-
cls._check_model_name_is_valid(model_name)
461-
_, _, res, _ = efficientnet_params(model_name)
462-
return res
463-
464-
@classmethod
465-
def _check_model_name_is_valid(cls, model_name):
466-
"""Validates model name.
467-
468-
Args:
469-
model_name (str): Name for efficientnet.
470-
471-
Returns:
472-
bool: Is a valid name or not.
473-
"""
474-
if model_name not in VALID_MODELS:
475-
raise ValueError("model_name should be one of: " + ", ".join(VALID_MODELS))
476-
477-
def _change_in_channels(self, in_channels):
478-
"""Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
479-
480-
Args:
481-
in_channels (int): Input data's channel number.
482-
"""
483-
if in_channels != 3:
484-
Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
485-
out_channels = round_filters(32, self._global_params)
486-
self._conv_stem = Conv2d(
487-
in_channels, out_channels, kernel_size=3, stride=2, bias=False
488-
)
489-
490-
491-
"""utils.py - Helper functions for building the model and for loading model parameters.
492-
These helper functions are built to mirror those in the official TensorFlow implementation.
493-
"""
494-
495-
# Author: lukemelas (github username)
496-
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
497-
# With adjustments and added comments by workingcoder (github username).
498-
499377

500378
################################################################################
501379
# Help functions for model architecture
@@ -553,15 +431,6 @@ def _change_in_channels(self, in_channels):
553431
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
554432
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
555433

556-
# Swish activation function
557-
if hasattr(nn, "SiLU"):
558-
Swish = nn.SiLU
559-
else:
560-
# For compatibility with old PyTorch versions
561-
class Swish(nn.Module):
562-
def forward(self, x):
563-
return x * torch.sigmoid(x)
564-
565434

566435
# A memory-efficient implementation of Swish function
567436
class SwishImplementation(torch.autograd.Function):

0 commit comments

Comments
 (0)