@@ -171,12 +171,11 @@ def set_swish(self, memory_efficient=True):
171171 Args:
172172 memory_efficient (bool): Whether to use memory-efficient version of swish.
173173 """
174- self ._swish = MemoryEfficientSwish () if memory_efficient else Swish ()
174+ self ._swish = MemoryEfficientSwish () if memory_efficient else nn . SiLU ()
175175
176176
177177class EfficientNet (nn .Module ):
178178 """EfficientNet model.
179- Most easily loaded with the .from_name or .from_pretrained methods.
180179
181180 Args:
182181 blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
@@ -275,7 +274,7 @@ def set_swish(self, memory_efficient=True):
275274 Args:
276275 memory_efficient (bool): Whether to use memory-efficient version of swish.
277276 """
278- self ._swish = MemoryEfficientSwish () if memory_efficient else Swish ()
277+ self ._swish = MemoryEfficientSwish () if memory_efficient else nn . SiLU ()
279278 for block in self ._blocks :
280279 block .set_swish (memory_efficient )
281280
@@ -375,127 +374,6 @@ def forward(self, inputs):
375374 x = self ._fc (x )
376375 return x
377376
378- @classmethod
379- def from_name (cls , model_name , in_channels = 3 , ** override_params ):
380- """Create an efficientnet model according to name.
381-
382- Args:
383- model_name (str): Name for efficientnet.
384- in_channels (int): Input data's channel number.
385- override_params (other key word params):
386- Params to override model's global_params.
387- Optional key:
388- 'width_coefficient', 'depth_coefficient',
389- 'image_size', 'dropout_rate',
390- 'num_classes', 'batch_norm_momentum',
391- 'batch_norm_epsilon', 'drop_connect_rate',
392- 'depth_divisor', 'min_depth'
393-
394- Returns:
395- An efficientnet model.
396- """
397- cls ._check_model_name_is_valid (model_name )
398- blocks_args , global_params = get_model_params (model_name , override_params )
399- model = cls (blocks_args , global_params )
400- model ._change_in_channels (in_channels )
401- return model
402-
403- @classmethod
404- def from_pretrained (
405- cls ,
406- model_name ,
407- weights_path = None ,
408- advprop = False ,
409- in_channels = 3 ,
410- num_classes = 1000 ,
411- ** override_params ,
412- ):
413- """Create an efficientnet model according to name.
414-
415- Args:
416- model_name (str): Name for efficientnet.
417- weights_path (None or str):
418- str: path to pretrained weights file on the local disk.
419- None: use pretrained weights downloaded from the Internet.
420- advprop (bool):
421- Whether to load pretrained weights
422- trained with advprop (valid when weights_path is None).
423- in_channels (int): Input data's channel number.
424- num_classes (int):
425- Number of categories for classification.
426- It controls the output size for final linear layer.
427- override_params (other key word params):
428- Params to override model's global_params.
429- Optional key:
430- 'width_coefficient', 'depth_coefficient',
431- 'image_size', 'dropout_rate',
432- 'batch_norm_momentum',
433- 'batch_norm_epsilon', 'drop_connect_rate',
434- 'depth_divisor', 'min_depth'
435-
436- Returns:
437- A pretrained efficientnet model.
438- """
439- model = cls .from_name (model_name , num_classes = num_classes , ** override_params )
440- load_pretrained_weights (
441- model ,
442- model_name ,
443- weights_path = weights_path ,
444- load_fc = (num_classes == 1000 ),
445- advprop = advprop ,
446- )
447- model ._change_in_channels (in_channels )
448- return model
449-
450- @classmethod
451- def get_image_size (cls , model_name ):
452- """Get the input image size for a given efficientnet model.
453-
454- Args:
455- model_name (str): Name for efficientnet.
456-
457- Returns:
458- Input image size (resolution).
459- """
460- cls ._check_model_name_is_valid (model_name )
461- _ , _ , res , _ = efficientnet_params (model_name )
462- return res
463-
464- @classmethod
465- def _check_model_name_is_valid (cls , model_name ):
466- """Validates model name.
467-
468- Args:
469- model_name (str): Name for efficientnet.
470-
471- Returns:
472- bool: Is a valid name or not.
473- """
474- if model_name not in VALID_MODELS :
475- raise ValueError ("model_name should be one of: " + ", " .join (VALID_MODELS ))
476-
477- def _change_in_channels (self , in_channels ):
478- """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
479-
480- Args:
481- in_channels (int): Input data's channel number.
482- """
483- if in_channels != 3 :
484- Conv2d = get_same_padding_conv2d (image_size = self ._global_params .image_size )
485- out_channels = round_filters (32 , self ._global_params )
486- self ._conv_stem = Conv2d (
487- in_channels , out_channels , kernel_size = 3 , stride = 2 , bias = False
488- )
489-
490-
491- """utils.py - Helper functions for building the model and for loading model parameters.
492- These helper functions are built to mirror those in the official TensorFlow implementation.
493- """
494-
495- # Author: lukemelas (github username)
496- # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
497- # With adjustments and added comments by workingcoder (github username).
498-
499377
500378################################################################################
501379# Help functions for model architecture
@@ -553,15 +431,6 @@ def _change_in_channels(self, in_channels):
553431GlobalParams .__new__ .__defaults__ = (None ,) * len (GlobalParams ._fields )
554432BlockArgs .__new__ .__defaults__ = (None ,) * len (BlockArgs ._fields )
555433
556- # Swish activation function
557- if hasattr (nn , "SiLU" ):
558- Swish = nn .SiLU
559- else :
560- # For compatibility with old PyTorch versions
561- class Swish (nn .Module ):
562- def forward (self , x ):
563- return x * torch .sigmoid (x )
564-
565434
566435# A memory-efficient implementation of Swish function
567436class SwishImplementation (torch .autograd .Function ):
0 commit comments