Skip to content

Commit c533bc2

Browse files
committed
supports efficientdet-d7x now
1 parent 99ba5d9 commit c533bc2

File tree

4 files changed

+74
-29
lines changed

4 files changed

+74
-29
lines changed

backbone.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# Author: Zylo117
22

3-
import math
4-
53
import torch
64
from torch import nn
75

@@ -14,12 +12,13 @@ def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs
1412
super(EfficientDetBackbone, self).__init__()
1513
self.compound_coef = compound_coef
1614

17-
self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6]
18-
self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
19-
self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
20-
self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
21-
self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
22-
self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.]
15+
self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7]
16+
self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384]
17+
self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8]
18+
self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
19+
self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5]
20+
self.pyramid_levels = [5, 5, 5, 5, 5, 5, 5, 5, 6]
21+
self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5., 4.]
2322
self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
2423
self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
2524
conv_channel_coef = {
@@ -32,6 +31,7 @@ def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs
3231
5: [64, 176, 512],
3332
6: [72, 200, 576],
3433
7: [72, 200, 576],
34+
8: [80, 224, 640],
3535
}
3636

3737
num_anchors = len(self.aspect_ratios) * self.num_scales
@@ -40,17 +40,22 @@ def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs
4040
*[BiFPN(self.fpn_num_filters[self.compound_coef],
4141
conv_channel_coef[compound_coef],
4242
True if _ == 0 else False,
43-
attention=True if compound_coef < 6 else False)
43+
attention=True if compound_coef < 6 else False,
44+
use_p8=compound_coef > 7)
4445
for _ in range(self.fpn_cell_repeats[compound_coef])])
4546

4647
self.num_classes = num_classes
4748
self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
48-
num_layers=self.box_class_repeats[self.compound_coef])
49+
num_layers=self.box_class_repeats[self.compound_coef],
50+
pyramid_levels=self.pyramid_levels[self.compound_coef])
4951
self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
5052
num_classes=num_classes,
51-
num_layers=self.box_class_repeats[self.compound_coef])
53+
num_layers=self.box_class_repeats[self.compound_coef],
54+
pyramid_levels=self.pyramid_levels[self.compound_coef])
5255

53-
self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef], **kwargs)
56+
self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
57+
pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(),
58+
**kwargs)
5459

5560
self.backbone_net = EfficientNet(self.backbone_compound_coef[compound_coef], load_weights)
5661

efficientdet/model.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ class BiFPN(nn.Module):
5757
modified by Zylo117
5858
"""
5959

60-
def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4, onnx_export=False, attention=True):
60+
def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4, onnx_export=False, attention=True,
61+
use_p8=False):
6162
"""
6263
6364
Args:
@@ -70,6 +71,8 @@ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4,
7071
"""
7172
super(BiFPN, self).__init__()
7273
self.epsilon = epsilon
74+
self.use_p8 = use_p8
75+
7376
# Conv layers
7477
self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
7578
self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
@@ -79,6 +82,9 @@ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4,
7982
self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
8083
self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
8184
self.conv7_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
85+
if use_p8:
86+
self.conv7_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
87+
self.conv8_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
8288

8389
# Feature scaling layers
8490
self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
@@ -90,6 +96,9 @@ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4,
9096
self.p5_downsample = MaxPool2dStaticSamePadding(3, 2)
9197
self.p6_downsample = MaxPool2dStaticSamePadding(3, 2)
9298
self.p7_downsample = MaxPool2dStaticSamePadding(3, 2)
99+
if use_p8:
100+
self.p7_upsample = nn.Upsample(scale_factor=2, mode='nearest')
101+
self.p8_downsample = MaxPool2dStaticSamePadding(3, 2)
93102

94103
self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
95104

@@ -116,6 +125,10 @@ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4,
116125
self.p6_to_p7 = nn.Sequential(
117126
MaxPool2dStaticSamePadding(3, 2)
118127
)
128+
if use_p8:
129+
self.p7_to_p8 = nn.Sequential(
130+
MaxPool2dStaticSamePadding(3, 2)
131+
)
119132

120133
self.p4_down_channel_2 = nn.Sequential(
121134
Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
@@ -172,11 +185,11 @@ def forward(self, inputs):
172185
# elif later phase, upsample to target phase's by nearest interpolation
173186

174187
if self.attention:
175-
p3_out, p4_out, p5_out, p6_out, p7_out = self._forward_fast_attention(inputs)
188+
outs = self._forward_fast_attention(inputs)
176189
else:
177-
p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs)
190+
outs = self._forward(inputs)
178191

179-
return p3_out, p4_out, p5_out, p6_out, p7_out
192+
return outs
180193

181194
def _forward_fast_attention(self, inputs):
182195
if self.first_time:
@@ -258,19 +271,34 @@ def _forward(self, inputs):
258271

259272
p6_in = self.p5_to_p6(p5)
260273
p7_in = self.p6_to_p7(p6_in)
274+
if self.use_p8:
275+
p8_in = self.p7_to_p8(p7_in)
261276

262277
p3_in = self.p3_down_channel(p3)
263278
p4_in = self.p4_down_channel(p4)
264279
p5_in = self.p5_down_channel(p5)
265280

266281
else:
267-
# P3_0, P4_0, P5_0, P6_0 and P7_0
268-
p3_in, p4_in, p5_in, p6_in, p7_in = inputs
282+
if self.use_p8:
283+
# P3_0, P4_0, P5_0, P6_0, P7_0 and P8_0
284+
p3_in, p4_in, p5_in, p6_in, p7_in, p8_in = inputs
285+
else:
286+
# P3_0, P4_0, P5_0, P6_0 and P7_0
287+
p3_in, p4_in, p5_in, p6_in, p7_in = inputs
269288

270-
# P7_0 to P7_2
289+
if self.use_p8:
290+
# P8_0 to P8_2
271291

272-
# Connections for P6_0 and P7_0 to P6_1 respectively
273-
p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))
292+
# Connections for P7_0 and P8_0 to P7_1 respectively
293+
p7_up = self.conv7_up(self.swish(p7_in + self.p7_upsample(p8_in)))
294+
295+
# Connections for P6_0 and P7_1 to P6_1 respectively
296+
p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_up)))
297+
else:
298+
# P7_0 to P7_2
299+
300+
# Connections for P6_0 and P7_0 to P6_1 respectively
301+
p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))
274302

275303
# Connections for P5_0 and P6_1 to P5_1 respectively
276304
p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up)))
@@ -297,26 +325,36 @@ def _forward(self, inputs):
297325
p6_out = self.conv6_down(
298326
self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))
299327

300-
# Connections for P7_0 and P6_2 to P7_2
301-
p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))
328+
if self.use_p8:
329+
# Connections for P7_0, P7_1 and P6_2 to P7_2 respectively
330+
p7_out = self.conv7_down(
331+
self.swish(p7_in + p7_up + self.p7_downsample(p6_out)))
302332

303-
return p3_out, p4_out, p5_out, p6_out, p7_out
333+
# Connections for P8_0 and P7_2 to P8_2
334+
p8_out = self.conv8_down(self.swish(p8_in + self.p8_downsample(p7_out)))
335+
336+
return p3_out, p4_out, p5_out, p6_out, p7_out, p8_out
337+
else:
338+
# Connections for P7_0 and P6_2 to P7_2
339+
p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))
340+
341+
return p3_out, p4_out, p5_out, p6_out, p7_out
304342

305343

306344
class Regressor(nn.Module):
307345
"""
308346
modified by Zylo117
309347
"""
310348

311-
def __init__(self, in_channels, num_anchors, num_layers, onnx_export=False):
349+
def __init__(self, in_channels, num_anchors, num_layers, pyramid_levels=5, onnx_export=False):
312350
super(Regressor, self).__init__()
313351
self.num_layers = num_layers
314352

315353
self.conv_list = nn.ModuleList(
316354
[SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
317355
self.bn_list = nn.ModuleList(
318356
[nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
319-
range(5)])
357+
range(pyramid_levels)])
320358
self.header = SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False)
321359
self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
322360

@@ -344,7 +382,7 @@ class Classifier(nn.Module):
344382
modified by Zylo117
345383
"""
346384

347-
def __init__(self, in_channels, num_anchors, num_classes, num_layers, onnx_export=False):
385+
def __init__(self, in_channels, num_anchors, num_classes, num_layers, pyramid_levels=5, onnx_export=False):
348386
super(Classifier, self).__init__()
349387
self.num_anchors = num_anchors
350388
self.num_classes = num_classes
@@ -353,7 +391,7 @@ def __init__(self, in_channels, num_anchors, num_classes, num_layers, onnx_expor
353391
[SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
354392
self.bn_list = nn.ModuleList(
355393
[nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
356-
range(5)])
394+
range(pyramid_levels)])
357395
self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False)
358396
self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
359397

efficientdet/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs):
6363

6464
if pyramid_levels is None:
6565
self.pyramid_levels = [3, 4, 5, 6, 7]
66+
else:
67+
self.pyramid_levels = pyramid_levels
6668

6769
self.strides = kwargs.get('strides', [2 ** x for x in self.pyramid_levels])
6870
self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))

efficientdet_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
color_list = standard_to_bgr(STANDARD_COLORS)
4747
# TensorFlow's bilinear interpolation differs from every other implementation's; just make do
48-
input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
48+
input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
4949
input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size
5050
ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size)
5151

0 commit comments

Comments
 (0)