Skip to content

Commit ce1ae43

Browse files
committed
Refactor effnet for torchscript
1 parent 28877ed commit ce1ae43

File tree

1 file changed

+113
-152
lines changed

1 file changed

+113
-152
lines changed

segmentation_models_pytorch/encoders/_efficientnet.py

Lines changed: 113 additions & 152 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,44 @@
1313
import math
1414
import collections
1515
from functools import partial
16+
from typing import List, Optional
17+
18+
# Parameters for the entire model (stem, all blocks, and head)
19+
GlobalParams = collections.namedtuple(
20+
"GlobalParams",
21+
[
22+
"width_coefficient",
23+
"depth_coefficient",
24+
"image_size",
25+
"dropout_rate",
26+
"num_classes",
27+
"batch_norm_momentum",
28+
"batch_norm_epsilon",
29+
"drop_connect_rate",
30+
"depth_divisor",
31+
"min_depth",
32+
"include_top",
33+
],
34+
)
35+
36+
# Parameters for an individual model block
37+
BlockArgs = collections.namedtuple(
38+
"BlockArgs",
39+
[
40+
"num_repeat",
41+
"kernel_size",
42+
"stride",
43+
"expand_ratio",
44+
"input_filters",
45+
"output_filters",
46+
"se_ratio",
47+
"id_skip",
48+
],
49+
)
50+
51+
# Set GlobalParams and BlockArgs's defaults
52+
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
53+
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
1654

1755

1856
class MBConvBlock(nn.Module):
@@ -29,77 +67,94 @@ class MBConvBlock(nn.Module):
2967
[3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
3068
"""
3169

32-
def __init__(self, block_args, global_params, image_size=None):
70+
def __init__(
71+
self, block_args: BlockArgs, global_params: GlobalParams, image_size=None
72+
):
3373
super().__init__()
34-
self._block_args = block_args
35-
self._bn_mom = (
36-
1 - global_params.batch_norm_momentum
37-
) # pytorch's difference from tensorflow
38-
self._bn_eps = global_params.batch_norm_epsilon
39-
self.has_se = (self._block_args.se_ratio is not None) and (
40-
0 < self._block_args.se_ratio <= 1
41-
)
42-
self.id_skip = (
74+
75+
self._has_expansion = block_args.expand_ratio != 1
76+
self._has_se = block_args.se_ratio is not None and 0 < block_args.se_ratio <= 1
77+
self._has_drop_connect = (
4378
block_args.id_skip
44-
) # whether to use skip connection and drop connect
79+
and block_args.stride == 1
80+
and block_args.input_filters == block_args.output_filters
81+
)
82+
83+
# Pytorch's difference from tensorflow
84+
bn_momentum = 1 - global_params.batch_norm_momentum
85+
bn_eps = global_params.batch_norm_epsilon
4586

4687
# Expansion phase (Inverted Bottleneck)
47-
inp = self._block_args.input_filters # number of input channels
48-
oup = (
49-
self._block_args.input_filters * self._block_args.expand_ratio
50-
) # number of output channels
51-
if self._block_args.expand_ratio != 1:
88+
input_channels = block_args.input_filters
89+
expanded_channels = input_channels * block_args.expand_ratio
90+
91+
if self._has_expansion:
5292
Conv2d = get_same_padding_conv2d(image_size=image_size)
5393
self._expand_conv = Conv2d(
54-
in_channels=inp, out_channels=oup, kernel_size=1, bias=False
94+
input_channels, expanded_channels, kernel_size=1, bias=False
5595
)
5696
self._bn0 = nn.BatchNorm2d(
57-
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
97+
expanded_channels,
98+
momentum=bn_momentum,
99+
eps=bn_eps,
58100
)
59-
# image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
101+
else:
102+
# for torchscript compatibility
103+
self._expand_conv = nn.Identity()
104+
self._bn0 = nn.Identity()
60105

61106
# Depthwise convolution phase
62-
k = self._block_args.kernel_size
63-
s = self._block_args.stride
107+
kernel_size = block_args.kernel_size
108+
stride = block_args.stride
64109
Conv2d = get_same_padding_conv2d(image_size=image_size)
65110
self._depthwise_conv = Conv2d(
66-
in_channels=oup,
67-
out_channels=oup,
68-
groups=oup, # groups makes it depthwise
69-
kernel_size=k,
70-
stride=s,
111+
in_channels=expanded_channels,
112+
out_channels=expanded_channels,
113+
groups=expanded_channels, # groups makes it depthwise
114+
kernel_size=kernel_size,
115+
stride=stride,
71116
bias=False,
72117
)
73118
self._bn1 = nn.BatchNorm2d(
74-
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
119+
expanded_channels,
120+
momentum=bn_momentum,
121+
eps=bn_eps,
75122
)
76-
image_size = calculate_output_image_size(image_size, s)
123+
image_size = calculate_output_image_size(image_size, stride)
77124

78125
# Squeeze and Excitation layer, if desired
79-
if self.has_se:
126+
if self._has_se:
127+
squeezed_channels = int(input_channels * block_args.se_ratio)
128+
squeezed_channels = max(1, squeezed_channels)
80129
Conv2d = get_same_padding_conv2d(image_size=(1, 1))
81-
num_squeezed_channels = max(
82-
1, int(self._block_args.input_filters * self._block_args.se_ratio)
83-
)
84130
self._se_reduce = Conv2d(
85-
in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
131+
in_channels=expanded_channels,
132+
out_channels=squeezed_channels,
133+
kernel_size=1,
86134
)
87135
self._se_expand = Conv2d(
88-
in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
136+
in_channels=squeezed_channels,
137+
out_channels=expanded_channels,
138+
kernel_size=1,
89139
)
90140

91141
# Pointwise convolution phase
92-
final_oup = self._block_args.output_filters
142+
output_channels = block_args.output_filters
93143
Conv2d = get_same_padding_conv2d(image_size=image_size)
94144
self._project_conv = Conv2d(
95-
in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
145+
in_channels=expanded_channels,
146+
out_channels=output_channels,
147+
kernel_size=1,
148+
bias=False,
96149
)
97150
self._bn2 = nn.BatchNorm2d(
98-
num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
151+
num_features=output_channels,
152+
momentum=bn_momentum,
153+
eps=bn_eps,
99154
)
100155
self._swish = nn.SiLU()
101156

102-
def forward(self, inputs, drop_connect_rate=None):
157+
def forward(self, inputs: torch.Tensor, drop_connect_rate: Optional[float] = None):
103158
"""MBConvBlock's forward function.
104159
105160
Args:
@@ -112,7 +167,7 @@ def forward(self, inputs, drop_connect_rate=None):
112167

113168
# Expansion and Depthwise Convolution
114169
x = inputs
115-
if self._block_args.expand_ratio != 1:
170+
if self._has_expansion:
116171
x = self._expand_conv(inputs)
117172
x = self._bn0(x)
118173
x = self._swish(x)
@@ -122,7 +177,7 @@ def forward(self, inputs, drop_connect_rate=None):
122177
x = self._swish(x)
123178

124179
# Squeeze and Excitation
125-
if self.has_se:
180+
if self._has_se:
126181
x_squeezed = F.adaptive_avg_pool2d(x, 1)
127182
x_squeezed = self._se_reduce(x_squeezed)
128183
x_squeezed = self._swish(x_squeezed)
@@ -134,17 +189,9 @@ def forward(self, inputs, drop_connect_rate=None):
134189
x = self._bn2(x)
135190

136191
# Skip connection and drop connect
137-
input_filters, output_filters = (
138-
self._block_args.input_filters,
139-
self._block_args.output_filters,
140-
)
141-
if (
142-
self.id_skip
143-
and self._block_args.stride == 1
144-
and input_filters == output_filters
145-
):
192+
if self._has_drop_connect:
146193
# The combination of skip connection and drop connect brings about stochastic depth.
147-
if drop_connect_rate:
194+
if drop_connect_rate is not None and drop_connect_rate > 0:
148195
x = drop_connect(x, p=drop_connect_rate, training=self.training)
149196
x = x + inputs # skip connection
150197
return x
@@ -169,10 +216,14 @@ class EfficientNet(nn.Module):
169216
>>> outputs = model(inputs)
170217
"""
171218

172-
def __init__(self, blocks_args=None, global_params=None):
219+
def __init__(self, blocks_args: List[BlockArgs], global_params: GlobalParams):
173220
super().__init__()
174-
assert isinstance(blocks_args, list), "blocks_args should be a list"
175-
assert len(blocks_args) > 0, "block args must be greater than 0"
221+
222+
if not isinstance(blocks_args, list):
223+
raise ValueError("blocks_args should be a list")
224+
if len(blocks_args) == 0:
225+
raise ValueError("block args must be greater than 0")
226+
176227
self._global_params = global_params
177228
self._blocks_args = blocks_args
178229

@@ -186,20 +237,16 @@ def __init__(self, blocks_args=None, global_params=None):
186237

187238
# Stem
188239
in_channels = 3 # rgb
189-
out_channels = round_filters(
190-
32, self._global_params
191-
) # number of output channels
240+
out_channels = round_filters(32, self._global_params)
192241
self._conv_stem = Conv2d(
193242
in_channels, out_channels, kernel_size=3, stride=2, bias=False
194243
)
195-
self._bn0 = nn.BatchNorm2d(
196-
num_features=out_channels, momentum=bn_mom, eps=bn_eps
197-
)
244+
self._bn0 = nn.BatchNorm2d(out_channels, momentum=bn_mom, eps=bn_eps)
198245
image_size = calculate_output_image_size(image_size, 2)
199246

200247
# Build blocks
201248
self._blocks = nn.ModuleList([])
202-
for block_args in self._blocks_args:
249+
for block_args in blocks_args:
203250
# Update block input and output filters based on depth multiplier.
204251
block_args = block_args._replace(
205252
input_filters=round_filters(
@@ -243,57 +290,8 @@ def __init__(self, blocks_args=None, global_params=None):
243290

244291
self._swish = nn.SiLU()
245292

246-
def extract_endpoints(self, inputs):
247-
"""Use convolution layer to extract features
248-
from reduction levels i in [1, 2, 3, 4, 5].
249-
250-
Args:
251-
inputs (tensor): Input tensor.
252-
253-
Returns:
254-
Dictionary of last intermediate features
255-
with reduction levels i in [1, 2, 3, 4, 5].
256-
Example:
257-
>>> import torch
258-
>>> from efficientnet.model import EfficientNet
259-
>>> inputs = torch.rand(1, 3, 224, 224)
260-
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
261-
>>> endpoints = model.extract_endpoints(inputs)
262-
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
263-
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
264-
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
265-
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
266-
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
267-
>>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
268-
"""
269-
endpoints = dict()
270-
271-
# Stem
272-
x = self._swish(self._bn0(self._conv_stem(inputs)))
273-
prev_x = x
274-
275-
# Blocks
276-
for idx, block in enumerate(self._blocks):
277-
drop_connect_rate = self._global_params.drop_connect_rate
278-
if drop_connect_rate:
279-
drop_connect_rate *= float(idx) / len(
280-
self._blocks
281-
) # scale drop connect_rate
282-
x = block(x, drop_connect_rate=drop_connect_rate)
283-
if prev_x.size(2) > x.size(2):
284-
endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
285-
elif idx == len(self._blocks) - 1:
286-
endpoints["reduction_{}".format(len(endpoints) + 1)] = x
287-
prev_x = x
288-
289-
# Head
290-
x = self._swish(self._bn1(self._conv_head(x)))
291-
endpoints["reduction_{}".format(len(endpoints) + 1)] = x
292-
293-
return endpoints
294-
295293
def extract_features(self, inputs):
296-
"""use convolution layer to extract feature .
294+
"""Use convolution layer to extract feature.
297295
298296
Args:
299297
inputs (tensor): Input tensor.
@@ -309,9 +307,8 @@ def extract_features(self, inputs):
309307
for idx, block in enumerate(self._blocks):
310308
drop_connect_rate = self._global_params.drop_connect_rate
311309
if drop_connect_rate:
312-
drop_connect_rate *= float(idx) / len(
313-
self._blocks
314-
) # scale drop connect_rate
310+
# scale drop connect_rate
311+
drop_connect_rate *= float(idx) / len(self._blocks)
315312
x = block(x, drop_connect_rate=drop_connect_rate)
316313

317314
# Head
@@ -321,7 +318,7 @@ def extract_features(self, inputs):
321318

322319
def forward(self, inputs):
323320
"""EfficientNet's forward function.
324-
Calls extract_features to extract features, applies final linear layer, and returns logits.
321+
Calls extract_features to extract features, applies final linear layer, and returns logits.
325322
326323
Args:
327324
inputs (tensor): Input tensor.
@@ -331,6 +328,7 @@ def forward(self, inputs):
331328
"""
332329
# Convolution layers
333330
x = self.extract_features(inputs)
331+
334332
# Pooling and final linear layer
335333
x = self._avg_pooling(x)
336334
if self._global_params.include_top:
@@ -358,43 +356,6 @@ def forward(self, inputs):
358356
# It's an additional function, not used in EfficientNet,
359357
# but can be used in other model (such as EfficientDet).
360358

361-
# Parameters for the entire model (stem, all blocks, and head)
362-
GlobalParams = collections.namedtuple(
363-
"GlobalParams",
364-
[
365-
"width_coefficient",
366-
"depth_coefficient",
367-
"image_size",
368-
"dropout_rate",
369-
"num_classes",
370-
"batch_norm_momentum",
371-
"batch_norm_epsilon",
372-
"drop_connect_rate",
373-
"depth_divisor",
374-
"min_depth",
375-
"include_top",
376-
],
377-
)
378-
379-
# Parameters for an individual model block
380-
BlockArgs = collections.namedtuple(
381-
"BlockArgs",
382-
[
383-
"num_repeat",
384-
"kernel_size",
385-
"stride",
386-
"expand_ratio",
387-
"input_filters",
388-
"output_filters",
389-
"se_ratio",
390-
"id_skip",
391-
],
392-
)
393-
394-
# Set GlobalParams and BlockArgs's defaults
395-
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
396-
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
397-
398359

399360
def round_filters(filters, global_params):
400361
"""Calculate and round number of filters based on width multiplier.
@@ -442,7 +403,7 @@ def round_repeats(repeats, global_params):
442403
return int(math.ceil(multiplier * repeats))
443404

444405

445-
def drop_connect(inputs, p, training):
406+
def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor:
446407
"""Drop connect.
447408
448409
Args:

0 commit comments

Comments
 (0)