This repository was archived by the owner on Dec 8, 2024. It is now read-only.
37 changes: 17 additions & 20 deletions models/efficientnet.py
@@ -46,6 +46,7 @@ def __init__(self, block_args, global_params):
in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
self._bn0 = nn.BatchNorm2d(
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

# Depthwise convolution phase
k = self._block_args.kernel_size
s = self._block_args.stride
@@ -176,6 +177,8 @@ def __init__(self, blocks_args=None, global_params=None):
num_features=out_channels, momentum=bn_mom, eps=bn_eps)

# Final linear layer
# Not used here: this backbone only outputs feature maps
# as inputs for FPN / BiFPN in object detection tasks
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
self._dropout = nn.Dropout(self._global_params.dropout_rate)
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
@@ -188,7 +191,7 @@ def set_swish(self, memory_efficient=True):
block.set_swish(memory_efficient)

def extract_features(self, inputs):
""" Returns output of the final convolution layer """
"""Blocks repeat themselves, only keep the last output of each repeated phase"""
# Stem
x = self._swish(self._bn0(self._conv_stem(inputs)))

@@ -202,17 +205,17 @@ def extract_features(self, inputs):
drop_connect_rate *= float(idx) / len(self._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
num_repeat = num_repeat + 1
# add the last output map after block repetition
if(num_repeat == self._blocks_args[index].num_repeat):
num_repeat = 0
index = index + 1
P.append(x)
return P

def forward(self, inputs):
""" Calls extract_features to extract features, applies final linear layer, and returns logits. """
""" Calls extract_features to extract features and returns the features maps of each blocks (after repetition)"""
# Convolution layers
P = self.extract_features(inputs)
return P
return self.extract_features(inputs)

@classmethod
def from_name(cls, model_name, override_params=None):
@@ -222,11 +225,11 @@ def from_name(cls, model_name, override_params=None):
return cls(blocks_args, global_params)

@classmethod
def from_pretrained(cls, model_name, num_classes=1000, in_channels=3):
def from_pretrained(cls, model_name, num_classes=1000, in_channels=3, advprop=False):
model = cls.from_name(model_name, override_params={
'num_classes': num_classes})
load_pretrained_weights(
model, model_name, load_fc=(num_classes == 1000))
model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
if in_channels != 3:
Conv2d = get_same_padding_conv2d(
image_size=model._global_params.image_size)
@@ -235,14 +238,6 @@ def from_pretrained(cls, model_name, num_classes=1000, in_channels=3):
in_channels, out_channels, kernel_size=3, stride=2, bias=False)
return model

@classmethod
def from_pretrained(cls, model_name, num_classes=1000):
model = cls.from_name(model_name, override_params={
'num_classes': num_classes})
load_pretrained_weights(
model, model_name, load_fc=(num_classes == 1000))

return model

@classmethod
def get_image_size(cls, model_name):
@@ -251,11 +246,11 @@ def get_image_size(cls, model_name):
return res

@classmethod
def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
""" Validates model name. None that pretrained weights are only available for
the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
num_models = 4 if also_need_pretrained_weights else 8
valid_models = ['efficientnet-b'+str(i) for i in range(num_models)]
def _check_model_name_is_valid(cls, model_name):
""" Validates model name. Note that only pretrained weights
with adversarial training (AdvProp) are available for EfficientNet-B8.
"""
valid_models = ['efficientnet-b'+str(i) for i in range(9)]
if model_name not in valid_models:
raise ValueError('model_name should be one of: ' +
', '.join(valid_models))
@@ -274,4 +269,6 @@ def get_list_features(self):
P = model(inputs)
for idx, p in enumerate(P):
print('P{}: {}'.format(idx, p.size()))
# print('model: ', model)
verbose = False
if verbose:
print('model: ', model)
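
For orientation (not part of the diff): with these changes the backbone no longer produces classification logits; calling the model returns the list of per-stage feature maps P built in extract_features. A minimal sketch, assuming torch is installed, the class is importable from models.efficientnet, and a 224x224 input is acceptable for efficientnet-b0:

# Minimal usage sketch, not part of the diff.
import torch
from models.efficientnet import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b0')  # advprop defaults to False
model.eval()

with torch.no_grad():
    # forward() now returns the list of feature maps built in extract_features()
    feature_maps = model(torch.randn(1, 3, 224, 224))

for idx, fmap in enumerate(feature_maps):
    print('P{}: {}'.format(idx, tuple(fmap.shape)))
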
39 changes: 27 additions & 12 deletions models/utils.py
@@ -302,27 +302,42 @@ def get_model_params(model_name, override_params):
return blocks_args, global_params



url_map = {
'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth',
'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth',
'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth',
'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth',
'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth',
'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth',
'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth',
'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth',
'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth',
'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth',
'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth',
'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth',
'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth',
'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth',
'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth',
}


url_map_advprop = {
'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth',
'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth',
'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth',
'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth',
'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth',
'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth',
'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth',
'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth',
'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth',
}


def load_pretrained_weights(model, model_name, load_fc=True):
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
""" Loads pretrained weights, and downloads if loading for the first time. """
state_dict = model_zoo.load_url(url_map[model_name])
# AutoAugment or Advprop (different preprocessing)
url_map_ = url_map_advprop if advprop else url_map
state_dict = model_zoo.load_url(url_map_[model_name])
if load_fc:
model.load_state_dict(state_dict)
else:
state_dict.pop('_fc.weight')
state_dict.pop('_fc.bias')
res = model.load_state_dict(state_dict, strict=False)
assert set(res.missing_keys) == set(
['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
print('Loaded pretrained weights for {}'.format(model_name))
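
A note on the new advprop flag (a hedged sketch, not something this diff specifies): AdvProp checkpoints are conventionally paired with inputs scaled to [-1, 1] instead of ImageNet mean/std normalization, which is what the "different preprocessing" comment above refers to. The transform values below are an assumption based on common EfficientNet usage:

# Sketch only, not part of the diff: typical preprocessing for each set of weights.
from torchvision import transforms

advprop = True  # matches the flag passed to from_pretrained / load_pretrained_weights
if advprop:
    # AdvProp weights: scale inputs to [-1, 1]
    normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
else:
    # AutoAugment weights: standard ImageNet normalization
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])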