Skip to content

Commit 4ef9c15

Browse files
committed
Add new model weights for bilinear/bilinear FPN up/down sample. Add support for training models with mish or other activations with torchscript considerations.
1 parent 36a232b commit 4ef9c15

File tree

5 files changed

+68
-32
lines changed

5 files changed

+68
-32
lines changed

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ Aside from the default model configs, there is a lot of flexibility to facilitat
1616

1717
## Updates
1818

19+
### 2021-02-18
20+
* Add some new model weights with bilinear interpolation for upsample and downsample in FPN.
21+
* 40.9 mAP - `efficientdet_q1` (replace prev model at 40.6)
22+
* 43.2 mAP -`cspresdet50`
23+
* 45.2 mAP - `cspdarkdet5m`
24+
1925
### 2020-12-07
2026
* Training w/ fully jit scripted model + bench (`--torchscript`) is possible with inclusion of ModelEmaV2 from `timm` and previous torchscript compat additions. Big speed gains for CPU bound training.
2127
* Add weights for alternate FPN layouts. QuadFPN experiments (`efficientdet_q0/q1/q2`) and CSPResDeXt + PAN (`cspresdext50pan`). See updated table below. Special thanks to [Artus](https://twitter.com/artuskg) for providing resources for training the Q2 model.
@@ -114,11 +120,13 @@ The table below contains models with pretrained weights. There are quite a numbe
114120
| efficientdet_q0.pth | 35.7 | TBD | N/A | N/A | 4.13 |
115121
| efficientdet_d1.pth | 39.4 | 39.5 | 39.1 | 39.6 | 6.62 |
116122
| tf_efficientdet_d1.pth | 40.1 | TBD | 40.2 | 40.5 | 6.63 |
117-
| efficientdet_q1.pth | 40.6 | TBD | N/A | N/A | 6.98 |
123+
| efficientdet_q1.pth | 40.9 | TBD | N/A | N/A | 6.98 |
118124
| cspresdext50pan | 41.2 | TBD | N/A | N/A | 22.2 |
119125
| resdet50 | 41.6 | TBD | N/A | N/A | 27.6 |
120126
| efficientdet_q2.pth | 43.1 | TBD | N/A | N/A | 8.81 |
127+
| cspresdet50 | 43.2 | TBD | N/A | N/A | 24.3 |
121128
| tf_efficientdet_d2.pth | 43.4 | TBD | 42.5 | 43 | 8.10 |
129+
| cspdarkdet53m | 45.2 | TBD | N/A | N/A | 35.6 |
122130
| tf_efficientdet_d3.pth | 47.1 | TBD | 47.2 | 47.5 | 12.0 |
123131
| tf_efficientdet_d4.pth | 49.2 | TBD | 49.3 | 49.7 | 20.7 |
124132
| tf_efficientdet_d5.pth | 51.2 | TBD | 51.2 | 51.5 | 33.7 |

effdet/config/model_config.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def default_detection_model_configs():
1818

1919
h.backbone_name = 'tf_efficientnet_b1'
2020
h.backbone_args = None # FIXME sort out kwargs vs config for backbone creation
21+
h.backbone_indices = None
2122

2223
# model specific, input preprocessing parameters
2324
h.image_size = (640, 640)
@@ -167,21 +168,21 @@ def default_detection_model_configs():
167168
cspresdet50=dict(
168169
name='cspresdet50',
169170
backbone_name='cspresnet50',
170-
image_size=(640, 640),
171+
image_size=(768, 768),
171172
aspect_ratios=[1.0, 2.0, 0.5],
172173
fpn_channels=88,
173174
fpn_cell_repeats=4,
174175
box_class_repeats=3,
175176
pad_type='',
176177
act_type='leaky_relu',
177178
head_act_type='silu',
178-
downsample_type='max',
179+
downsample_type='bilinear',
179180
upsample_type='bilinear',
180181
redundant_bias=False,
181182
separable_conv=False,
182183
head_bn_level_first=True,
183184
backbone_args=dict(drop_path_rate=0.2),
184-
url='',
185+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/cspresdet50b-386da277.pth',
185186
),
186187
cspresdext50=dict(
187188
name='cspresdext50',
@@ -230,8 +231,30 @@ def default_detection_model_configs():
230231
separable_conv=False,
231232
head_bn_level_first=True,
232233
backbone_args=dict(drop_path_rate=0.2),
234+
backbone_indices=(3, 4, 5),
233235
url='',
234236
),
237+
cspdarkdet53m=dict(
238+
name='cspdarkdet53m',
239+
backbone_name='cspdarknet53',
240+
image_size=(768, 768),
241+
aspect_ratios=[1.0, 2.0, 0.5],
242+
fpn_channels=96,
243+
fpn_cell_repeats=4,
244+
box_class_repeats=3,
245+
pad_type='',
246+
fpn_name='qufpn_fa',
247+
act_type='leaky_relu',
248+
head_act_type='mish',
249+
downsample_type='bilinear',
250+
upsample_type='bilinear',
251+
redundant_bias=False,
252+
separable_conv=False,
253+
head_bn_level_first=True,
254+
backbone_args=dict(drop_path_rate=0.2),
255+
backbone_indices=(3, 4, 5),
256+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/cspdarkdet53m-79062b2d.pth',
257+
),
235258
mixdet_m=dict(
236259
name='mixdet_m',
237260
backbone_name='mixnet_m',
@@ -328,10 +351,12 @@ def default_detection_model_configs():
328351
box_class_repeats=3,
329352
pad_type='',
330353
fpn_name='qufpn_fa', # quad-fpn + fast attn experiment
354+
downsample_type='bilinear',
355+
upsample_type='bilinear',
331356
redundant_bias=False,
332357
head_bn_level_first=True,
333358
backbone_args=dict(drop_path_rate=0.2),
334-
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_q1-b238aba5.pth',
359+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_q1b-d0612140.pth',
335360
),
336361
efficientdet_q2=dict(
337362
name='efficientdet_q2',

effdet/efficientdet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,8 @@ def __init__(self, config, pretrained_backbone=True, alternate_init=False):
557557
self.config = config
558558
set_config_readonly(self.config)
559559
self.backbone = create_model(
560-
config.backbone_name, features_only=True, out_indices=(2, 3, 4),
560+
config.backbone_name, features_only=True,
561+
out_indices=self.config.backbone_indices or (2, 3, 4),
561562
pretrained=pretrained_backbone, **config.backbone_args)
562563
feature_info = get_feature_info(self.backbone)
563564
self.fpn = BiFpn(self.config, feature_info)

train.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from effdet.data import resolve_input_config, SkipSubset
4141
from effdet.anchors import Anchors, AnchorLabeler
4242
from timm.models import resume_checkpoint, load_checkpoint
43+
from timm.models.layers import set_layer_config
4344
from timm.utils import *
4445
from timm.optim import create_optimizer
4546
from timm.scheduler import create_scheduler
@@ -267,20 +268,21 @@ def main():
267268

268269
torch.manual_seed(args.seed + args.rank)
269270

270-
model = create_model(
271-
args.model,
272-
bench_task='train',
273-
num_classes=args.num_classes,
274-
pretrained=args.pretrained,
275-
pretrained_backbone=args.pretrained_backbone,
276-
redundant_bias=args.redundant_bias,
277-
label_smoothing=args.smoothing,
278-
legacy_focal=args.legacy_focal,
279-
jit_loss=args.jit_loss,
280-
soft_nms=args.soft_nms,
281-
bench_labeler=args.bench_labeler,
282-
checkpoint_path=args.initial_checkpoint,
283-
)
271+
with set_layer_config(scriptable=args.torchscript):
272+
model = create_model(
273+
args.model,
274+
bench_task='train',
275+
num_classes=args.num_classes,
276+
pretrained=args.pretrained,
277+
pretrained_backbone=args.pretrained_backbone,
278+
redundant_bias=args.redundant_bias,
279+
label_smoothing=args.smoothing,
280+
legacy_focal=args.legacy_focal,
281+
jit_loss=args.jit_loss,
282+
soft_nms=args.soft_nms,
283+
bench_labeler=args.bench_labeler,
284+
checkpoint_path=args.initial_checkpoint,
285+
)
284286
model_config = model.config # grab before we obscure with DP/DDP wrappers
285287

286288
if args.local_rank == 0:

validate.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111

1212
from effdet import create_model, create_evaluator, create_dataset, create_loader
1313
from effdet.data import resolve_input_config
14-
from effdet.evaluator import CocoEvaluator, PascalEvaluator
1514
from timm.utils import AverageMeter, setup_default_logging
16-
15+
from timm.models.layers import set_layer_config
1716

1817
has_apex = False
1918
try:
@@ -107,16 +106,17 @@ def validate(args):
107106
args.prefetcher = not args.no_prefetcher
108107

109108
# create model
110-
bench = create_model(
111-
args.model,
112-
bench_task='predict',
113-
num_classes=args.num_classes,
114-
pretrained=args.pretrained,
115-
redundant_bias=args.redundant_bias,
116-
soft_nms=args.soft_nms,
117-
checkpoint_path=args.checkpoint,
118-
checkpoint_ema=args.use_ema,
119-
)
109+
with set_layer_config(scriptable=args.torchscript):
110+
bench = create_model(
111+
args.model,
112+
bench_task='predict',
113+
num_classes=args.num_classes,
114+
pretrained=args.pretrained,
115+
redundant_bias=args.redundant_bias,
116+
soft_nms=args.soft_nms,
117+
checkpoint_path=args.checkpoint,
118+
checkpoint_ema=args.use_ema,
119+
)
120120
model_config = bench.config
121121

122122
param_count = sum([m.numel() for m in bench.parameters()])

0 commit comments

Comments
 (0)