@@ -171,12 +171,11 @@ def set_swish(self, memory_efficient=True):
171
171
Args:
172
172
memory_efficient (bool): Whether to use memory-efficient version of swish.
173
173
"""
174
- self ._swish = MemoryEfficientSwish () if memory_efficient else Swish ()
174
+ self ._swish = MemoryEfficientSwish () if memory_efficient else nn . SiLU ()
175
175
176
176
177
177
class EfficientNet (nn .Module ):
178
178
"""EfficientNet model.
179
- Most easily loaded with the .from_name or .from_pretrained methods.
180
179
181
180
Args:
182
181
blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
@@ -275,7 +274,7 @@ def set_swish(self, memory_efficient=True):
275
274
Args:
276
275
memory_efficient (bool): Whether to use memory-efficient version of swish.
277
276
"""
278
- self ._swish = MemoryEfficientSwish () if memory_efficient else Swish ()
277
+ self ._swish = MemoryEfficientSwish () if memory_efficient else nn . SiLU ()
279
278
for block in self ._blocks :
280
279
block .set_swish (memory_efficient )
281
280
@@ -375,127 +374,6 @@ def forward(self, inputs):
375
374
x = self ._fc (x )
376
375
return x
377
376
378
- @classmethod
379
- def from_name (cls , model_name , in_channels = 3 , ** override_params ):
380
- """Create an efficientnet model according to name.
381
-
382
- Args:
383
- model_name (str): Name for efficientnet.
384
- in_channels (int): Input data's channel number.
385
- override_params (other key word params):
386
- Params to override model's global_params.
387
- Optional key:
388
- 'width_coefficient', 'depth_coefficient',
389
- 'image_size', 'dropout_rate',
390
- 'num_classes', 'batch_norm_momentum',
391
- 'batch_norm_epsilon', 'drop_connect_rate',
392
- 'depth_divisor', 'min_depth'
393
-
394
- Returns:
395
- An efficientnet model.
396
- """
397
- cls ._check_model_name_is_valid (model_name )
398
- blocks_args , global_params = get_model_params (model_name , override_params )
399
- model = cls (blocks_args , global_params )
400
- model ._change_in_channels (in_channels )
401
- return model
402
-
403
- @classmethod
404
- def from_pretrained (
405
- cls ,
406
- model_name ,
407
- weights_path = None ,
408
- advprop = False ,
409
- in_channels = 3 ,
410
- num_classes = 1000 ,
411
- ** override_params ,
412
- ):
413
- """Create an efficientnet model according to name.
414
-
415
- Args:
416
- model_name (str): Name for efficientnet.
417
- weights_path (None or str):
418
- str: path to pretrained weights file on the local disk.
419
- None: use pretrained weights downloaded from the Internet.
420
- advprop (bool):
421
- Whether to load pretrained weights
422
- trained with advprop (valid when weights_path is None).
423
- in_channels (int): Input data's channel number.
424
- num_classes (int):
425
- Number of categories for classification.
426
- It controls the output size for final linear layer.
427
- override_params (other key word params):
428
- Params to override model's global_params.
429
- Optional key:
430
- 'width_coefficient', 'depth_coefficient',
431
- 'image_size', 'dropout_rate',
432
- 'batch_norm_momentum',
433
- 'batch_norm_epsilon', 'drop_connect_rate',
434
- 'depth_divisor', 'min_depth'
435
-
436
- Returns:
437
- A pretrained efficientnet model.
438
- """
439
- model = cls .from_name (model_name , num_classes = num_classes , ** override_params )
440
- load_pretrained_weights (
441
- model ,
442
- model_name ,
443
- weights_path = weights_path ,
444
- load_fc = (num_classes == 1000 ),
445
- advprop = advprop ,
446
- )
447
- model ._change_in_channels (in_channels )
448
- return model
449
-
450
- @classmethod
451
- def get_image_size (cls , model_name ):
452
- """Get the input image size for a given efficientnet model.
453
-
454
- Args:
455
- model_name (str): Name for efficientnet.
456
-
457
- Returns:
458
- Input image size (resolution).
459
- """
460
- cls ._check_model_name_is_valid (model_name )
461
- _ , _ , res , _ = efficientnet_params (model_name )
462
- return res
463
-
464
- @classmethod
465
- def _check_model_name_is_valid (cls , model_name ):
466
- """Validates model name.
467
-
468
- Args:
469
- model_name (str): Name for efficientnet.
470
-
471
- Returns:
472
- bool: Is a valid name or not.
473
- """
474
- if model_name not in VALID_MODELS :
475
- raise ValueError ("model_name should be one of: " + ", " .join (VALID_MODELS ))
476
-
477
- def _change_in_channels (self , in_channels ):
478
- """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
479
-
480
- Args:
481
- in_channels (int): Input data's channel number.
482
- """
483
- if in_channels != 3 :
484
- Conv2d = get_same_padding_conv2d (image_size = self ._global_params .image_size )
485
- out_channels = round_filters (32 , self ._global_params )
486
- self ._conv_stem = Conv2d (
487
- in_channels , out_channels , kernel_size = 3 , stride = 2 , bias = False
488
- )
489
-
490
-
491
- """utils.py - Helper functions for building the model and for loading model parameters.
492
- These helper functions are built to mirror those in the official TensorFlow implementation.
493
- """
494
-
495
- # Author: lukemelas (github username)
496
- # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
497
- # With adjustments and added comments by workingcoder (github username).
498
-
499
377
500
378
################################################################################
501
379
# Help functions for model architecture
@@ -553,15 +431,6 @@ def _change_in_channels(self, in_channels):
553
431
GlobalParams .__new__ .__defaults__ = (None ,) * len (GlobalParams ._fields )
554
432
BlockArgs .__new__ .__defaults__ = (None ,) * len (BlockArgs ._fields )
555
433
556
- # Swish activation function
557
- if hasattr (nn , "SiLU" ):
558
- Swish = nn .SiLU
559
- else :
560
- # For compatibility with old PyTorch versions
561
- class Swish (nn .Module ):
562
- def forward (self , x ):
563
- return x * torch .sigmoid (x )
564
-
565
434
566
435
# A memory-efficient implementation of Swish function
567
436
class SwishImplementation (torch .autograd .Function ):
0 commit comments