Skip to content

Commit 9370d27

Browse files
committed
Move some post-processing/NMS constants (MAX_DETECTION_POINTS, MAX_DETECTIONS_PER_IMAGE) and the soft-NMS flag into model config options. Soft-NMS can also be enabled via the command line. Fix #120 and correct a comment typo for #150.
1 parent 52279b4 commit 9370d27

File tree

6 files changed

+40
-22
lines changed

6 files changed

+40
-22
lines changed

effdet/anchors.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,6 @@
4343
# The score for a dummy detection
4444
_DUMMY_DETECTION_SCORE = -1e5
4545

46-
# The maximum number of (anchor,class) pairs to keep for non-max suppression.
47-
MAX_DETECTION_POINTS = 5000
48-
49-
# The maximum number of detections per image.
50-
MAX_DETECTIONS_PER_IMAGE = 100
51-
5246

5347
def decode_box_outputs(rel_codes, anchors, output_xyxy: bool=False):
5448
"""Transforms relative regression coordinates to absolute positions.
@@ -97,17 +91,17 @@ def clip_boxes_xyxy(boxes: torch.Tensor, size: torch.Tensor):
9791
def generate_detections(
9892
cls_outputs, box_outputs, anchor_boxes, indices, classes,
9993
img_scale: Optional[torch.Tensor], img_size: Optional[torch.Tensor],
100-
max_det_per_image: int = MAX_DETECTIONS_PER_IMAGE, soft_nms: bool = False):
94+
max_det_per_image: int = 100, soft_nms: bool = False):
10195
"""Generates detections with RetinaNet model outputs and anchors.
10296
10397
Args:
10498
cls_outputs: a torch tensor with shape [N, 1], which has the highest class
10599
scores on all feature levels. The N is the number of selected
106-
top-K total anchors on all levels. (k being MAX_DETECTION_POINTS)
100+
top-K total anchors on all levels.
107101
108102
box_outputs: a torch tensor with shape [N, 4], which stacks box regression
109103
outputs on all feature levels. The N is the number of selected top-k
110-
total anchors on all levels. (k being MAX_DETECTION_POINTS)
104+
total anchors on all levels.
111105
112106
anchor_boxes: a torch tensor with shape [N, 4], which stacks anchors on all
113107
feature levels. The N is the number of selected top-k total anchors on all levels.
@@ -124,7 +118,7 @@ def generate_detections(
124118
max_det_per_image: an int constant, added as argument to make torchscript happy
125119
126120
Returns:
127-
detections: detection results in a tensor with shape [MAX_DETECTION_POINTS, 6],
121+
detections: detection results in a tensor with shape [max_det_per_image, 6],
128122
each row representing [x_min, y_min, x_max, y_max, score, class]
129123
"""
130124
assert box_outputs.shape[-1] == 4
@@ -147,7 +141,7 @@ def generate_detections(
147141
else:
148142
top_detection_idx = batched_nms(boxes, scores, classes, iou_threshold=0.5)
149143

150-
# keep only topk scoring predictions
144+
# keep only top max_det_per_image scoring predictions
151145
top_detection_idx = top_detection_idx[:max_det_per_image]
152146
boxes = boxes[top_detection_idx]
153147
scores = scores[top_detection_idx, None]
@@ -159,7 +153,7 @@ def generate_detections(
159153
# FIXME add option to convert boxes back to yxyx? Otherwise must be handled downstream if
160154
# that is the preferred output format.
161155

162-
# stack em and pad out to MAX_DETECTIONS_PER_IMAGE if necessary
156+
# stack em and pad out to max_det_per_image if necessary
163157
num_det = len(top_detection_idx)
164158
detections = torch.cat([boxes, scores, classes.float()], dim=1)
165159
if num_det < max_det_per_image:

effdet/bench.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Optional, Dict, List
66
import torch
77
import torch.nn as nn
8-
from .anchors import Anchors, AnchorLabeler, generate_detections, MAX_DETECTION_POINTS
8+
from .anchors import Anchors, AnchorLabeler, generate_detections
99
from .loss import DetectionLoss
1010

1111

@@ -14,7 +14,7 @@ def _post_process(
1414
box_outputs: List[torch.Tensor],
1515
num_levels: int,
1616
num_classes: int,
17-
max_detection_points: int = MAX_DETECTION_POINTS,
17+
max_detection_points: int = 5000,
1818
):
1919
"""Selects top-k predictions.
2020
@@ -59,14 +59,19 @@ def _post_process(
5959
@torch.jit.script
6060
def _batch_detection(
6161
batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
62-
img_scale: Optional[torch.Tensor] = None, img_size: Optional[torch.Tensor] = None):
62+
img_scale: Optional[torch.Tensor] = None,
63+
img_size: Optional[torch.Tensor] = None,
64+
max_det_per_image: int = 100,
65+
soft_nms: bool = False,
66+
):
6367
batch_detections = []
6468
# FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
6569
for i in range(batch_size):
6670
img_scale_i = None if img_scale is None else img_scale[i]
6771
img_size_i = None if img_size is None else img_size[i]
6872
detections = generate_detections(
69-
class_out[i], box_out[i], anchor_boxes, indices[i], classes[i], img_scale_i, img_size_i)
73+
class_out[i], box_out[i], anchor_boxes, indices[i], classes[i],
74+
img_scale_i, img_size_i, max_det_per_image=max_det_per_image, soft_nms=soft_nms)
7075
batch_detections.append(detections)
7176
return torch.stack(batch_detections, dim=0)
7277

@@ -79,17 +84,23 @@ def __init__(self, model):
7984
self.num_levels = model.config.num_levels
8085
self.num_classes = model.config.num_classes
8186
self.anchors = Anchors.from_config(model.config)
87+
self.max_detection_points = model.config.max_detection_points
88+
self.max_det_per_image = model.config.max_det_per_image
89+
self.soft_nms = model.config.soft_nms
8290

8391
def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
8492
class_out, box_out = self.model(x)
8593
class_out, box_out, indices, classes = _post_process(
86-
class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
94+
class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes,
95+
max_detection_points=self.max_detection_points)
8796
if img_info is None:
8897
img_scale, img_size = None, None
8998
else:
9099
img_scale, img_size = img_info['img_scale'], img_info['img_size']
91100
return _batch_detection(
92-
x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes, img_scale, img_size)
101+
x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes,
102+
img_scale, img_size, max_det_per_image=self.max_det_per_image, soft_nms=self.soft_nms
103+
)
93104

94105

95106
class DetBenchTrain(nn.Module):
@@ -100,6 +111,9 @@ def __init__(self, model, create_labeler=True):
100111
self.num_levels = model.config.num_levels
101112
self.num_classes = model.config.num_classes
102113
self.anchors = Anchors.from_config(model.config)
114+
self.max_detection_points = model.config.max_detection_points
115+
self.max_det_per_image = model.config.max_det_per_image
116+
self.soft_nms = model.config.soft_nms
103117
self.anchor_labeler = None
104118
if create_labeler:
105119
self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
@@ -122,10 +136,12 @@ def forward(self, x, target: Dict[str, torch.Tensor]):
122136
if not self.training:
123137
# if eval mode, output detections for evaluation
124138
class_out_pp, box_out_pp, indices, classes = _post_process(
125-
class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes)
139+
class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes,
140+
max_detection_points=self.max_detection_points)
126141
output['detections'] = _batch_detection(
127142
x.shape[0], class_out_pp, box_out_pp, self.anchors.boxes, indices, classes,
128-
target['img_scale'], target['img_size'])
143+
target['img_scale'], target['img_size'],
144+
max_det_per_image=self.max_det_per_image, soft_nms=self.soft_nms)
129145
return output
130146

131147

effdet/config/model_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def default_detection_model_configs():
7070
h.delta = 0.1
7171
h.box_loss_weight = 50.0
7272

73+
# nms
74+
h.soft_nms = False # use soft-nms, this is incredibly slow
75+
h.max_detection_points = 5000 # max detections for post process, input to NMS
76+
h.max_det_per_image = 100 # max detections per image limit, output of NMS
77+
7378
return h
7479

7580

effdet/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def create_model_from_config(
2323
pretrained_backbone = False # no point in loading backbone weights
2424

2525
# Config overrides, override some config values via kwargs.
26-
overrides = ('redundant_bias', 'label_smoothing', 'legacy_focal', 'jit_loss')
26+
overrides = ('redundant_bias', 'label_smoothing', 'legacy_focal', 'jit_loss', 'soft_nms')
2727
for ov in overrides:
2828
value = kwargs.pop(ov, None)
2929
if value is not None:

train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
parser.add_argument('--model', default='tf_efficientdet_d1', type=str, metavar='MODEL',
6464
help='Name of model to train (default: "tf_efficientdet_d1"')
6565
add_bool_arg(parser, 'redundant-bias', default=None, help='override model config for redundant bias')
66-
parser.set_defaults(redundant_bias=None)
66+
add_bool_arg(parser, 'soft-nms', default=None, help='override model config for soft-nms')
6767
parser.add_argument('--val-skip', type=int, default=0, metavar='N',
6868
help='Skip every N validation samples.')
6969
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
@@ -277,6 +277,7 @@ def main():
277277
label_smoothing=args.smoothing,
278278
legacy_focal=args.legacy_focal,
279279
jit_loss=args.jit_loss,
280+
soft_nms=args.soft_nms,
280281
bench_labeler=args.bench_labeler,
281282
checkpoint_path=args.initial_checkpoint,
282283
)

validate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def add_bool_arg(parser, name, default=False, help=''): # FIXME move to utils
5151
help='model architecture (default: tf_efficientdet_d1)')
5252
add_bool_arg(parser, 'redundant-bias', default=None,
5353
help='override model config for redundant bias layers')
54+
add_bool_arg(parser, 'soft-nms', default=None, help='override model config for soft-nms')
5455
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
5556
help='Override num_classes in model config if set. For fine-tuning from pretrained.')
5657
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
@@ -112,6 +113,7 @@ def validate(args):
112113
num_classes=args.num_classes,
113114
pretrained=args.pretrained,
114115
redundant_bias=args.redundant_bias,
116+
soft_nms=args.soft_nms,
115117
checkpoint_path=args.checkpoint,
116118
checkpoint_ema=args.use_ema,
117119
)

0 commit comments

Comments
 (0)