16
16
from torch .utils import model_zoo
17
17
18
18
19
- VALID_MODELS = (
20
- "efficientnet-b0" ,
21
- "efficientnet-b1" ,
22
- "efficientnet-b2" ,
23
- "efficientnet-b3" ,
24
- "efficientnet-b4" ,
25
- "efficientnet-b5" ,
26
- "efficientnet-b6" ,
27
- "efficientnet-b7" ,
28
- "efficientnet-b8" ,
29
- # Support the construction of 'efficientnet-l2' without pretrained weights
30
- "efficientnet-l2" ,
31
- )
32
-
33
-
34
19
class MBConvBlock (nn .Module ):
35
20
"""Mobile Inverted Residual Bottleneck Block.
36
21
@@ -772,7 +757,6 @@ def forward(self, x):
772
757
# get_model_params and efficientnet:
773
758
# Functions to get BlockArgs and GlobalParams for efficientnet
774
759
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
775
- # load_pretrained_weights: A function to load pretrained weights
776
760
777
761
778
762
class BlockDecoder (object ):
@@ -817,30 +801,6 @@ def _decode_block_string(block_string):
817
801
id_skip = ("noskip" not in block_string ),
818
802
)
819
803
820
- @staticmethod
821
- def _encode_block_string (block ):
822
- """Encode a block to a string.
823
-
824
- Args:
825
- block (namedtuple): A BlockArgs type argument.
826
-
827
- Returns:
828
- block_string: A String form of BlockArgs.
829
- """
830
- args = [
831
- "r%d" % block .num_repeat ,
832
- "k%d" % block .kernel_size ,
833
- "s%d%d" % (block .strides [0 ], block .strides [1 ]),
834
- "e%s" % block .expand_ratio ,
835
- "i%d" % block .input_filters ,
836
- "o%d" % block .output_filters ,
837
- ]
838
- if 0 < block .se_ratio <= 1 :
839
- args .append ("se%s" % block .se_ratio )
840
- if block .id_skip is False :
841
- args .append ("noskip" )
842
- return "_" .join (args )
843
-
844
804
@staticmethod
845
805
def decode (string_list ):
846
806
"""Decode a list of string notations to specify blocks inside the network.
@@ -857,21 +817,6 @@ def decode(string_list):
857
817
blocks_args .append (BlockDecoder ._decode_block_string (block_string ))
858
818
return blocks_args
859
819
860
- @staticmethod
861
- def encode (blocks_args ):
862
- """Encode a list of BlockArgs to a list of strings.
863
-
864
- Args:
865
- blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
866
-
867
- Returns:
868
- block_strings: A list of strings, each string is a notation of block.
869
- """
870
- block_strings = []
871
- for block in blocks_args :
872
- block_strings .append (BlockDecoder ._encode_block_string (block ))
873
- return block_strings
874
-
875
820
876
821
def efficientnet_params (model_name ):
877
822
"""Map EfficientNet model name to parameter coefficients.
@@ -1005,47 +950,3 @@ def get_model_params(model_name, override_params):
1005
950
"efficientnet-b7" : "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth" ,
1006
951
"efficientnet-b8" : "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth" ,
1007
952
}
1008
-
1009
- # TODO: add the petrained weights url map of 'efficientnet-l2'
1010
-
1011
-
1012
- def load_pretrained_weights (
1013
- model , model_name , weights_path = None , load_fc = True , advprop = False , verbose = True
1014
- ):
1015
- """Loads pretrained weights from weights path or download using url.
1016
-
1017
- Args:
1018
- model (Module): The whole model of efficientnet.
1019
- model_name (str): Model name of efficientnet.
1020
- weights_path (None or str):
1021
- str: path to pretrained weights file on the local disk.
1022
- None: use pretrained weights downloaded from the Internet.
1023
- load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
1024
- advprop (bool): Whether to load pretrained weights
1025
- trained with advprop (valid when weights_path is None).
1026
- """
1027
- if isinstance (weights_path , str ):
1028
- state_dict = torch .load (weights_path )
1029
- else :
1030
- # AutoAugment or Advprop (different preprocessing)
1031
- url_map_ = url_map_advprop if advprop else url_map
1032
- state_dict = model_zoo .load_url (url_map_ [model_name ])
1033
-
1034
- if load_fc :
1035
- ret = model .load_state_dict (state_dict , strict = False )
1036
- assert not ret .missing_keys , (
1037
- "Missing keys when loading pretrained weights: {}" .format (ret .missing_keys )
1038
- )
1039
- else :
1040
- state_dict .pop ("_fc.weight" )
1041
- state_dict .pop ("_fc.bias" )
1042
- ret = model .load_state_dict (state_dict , strict = False )
1043
- assert set (ret .missing_keys ) == set (["_fc.weight" , "_fc.bias" ]), (
1044
- "Missing keys when loading pretrained weights: {}" .format (ret .missing_keys )
1045
- )
1046
- assert not ret .unexpected_keys , (
1047
- "Missing keys when loading pretrained weights: {}" .format (ret .unexpected_keys )
1048
- )
1049
-
1050
- if verbose :
1051
- print ("Loaded pretrained weights for {}" .format (model_name ))
0 commit comments