Skip to content

Commit 3980d2a

Browse files
Pooya Davoodisaberkun
authored andcommitted
Add Combined NMS (#6138)
* Updating python API to use CombinedNonMaxSuppresion TF operator 1. Adds a unit test to test post_processing python API 2. Currently sets clip_window to None as the kernel uses the default clip_window of [0,0,1,1] 3. Added use_static_shapes to the API. In old API if use_static_shapes is true, then it pads/clips outputs to max_total_size, if specified. If not specified, it pads to num_classes*max_size_per_class. If use_static_shapes is false, it always pads/clips to max_total_size. Update unit test to account for clipped bouding boxes Changed the name to CombinedNonMaxSuppression based on feedback from Google Added additional parameters to combinedNMS python function. They are currently unused and required for networks like FasterRCNN and MaskRCNN * Delete selected_indices from API Because it was removed from CombinedNMS recently in the PR. * Improve doc of function combined_non_max_suppression * Enable CombinedNonMaxSuppression for first_stage_nms * fix bug * Ensure agnostic_nms is not used with combined_nms Remove redundant arguments from combined_nms * Fix pylint * Add checks for unsupported args * Fix pylint * Move combined_non_max_suppression to batch_multiclass_non_max_suppression Also rename combined_nms to use_combined_nms * Delete combined_nms for first_stage_nms because it does not work * Revert "Delete combined_nms for first_stage_nms because it does not work" This reverts commit 2a3cc51. * Use nmsed_additional_fields.get to avoid error * Merge combined_non_max_suppression with main nms function * Rename combined_nms for first stage nms * Improve docs * Use assertListEqual for numpy arrays * Fix pylint errors * End comments with period
1 parent f5e2211 commit 3980d2a

File tree

8 files changed

+130
-8
lines changed

8 files changed

+130
-8
lines changed

research/object_detection/builders/model_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
484484
iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
485485
max_size_per_class=frcnn_config.first_stage_max_proposals,
486486
max_total_size=frcnn_config.first_stage_max_proposals,
487-
use_static_shapes=use_static_shapes)
487+
use_static_shapes=use_static_shapes,
488+
use_combined_nms=frcnn_config.use_combined_nms_in_first_stage)
488489
first_stage_loc_loss_weight = (
489490
frcnn_config.first_stage_localization_loss_weight)
490491
first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

research/object_detection/builders/post_processing_builder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def _build_non_max_suppressor(nms_config):
8888
'max_total_detections.')
8989
if nms_config.soft_nms_sigma < 0.0:
9090
raise ValueError('soft_nms_sigma should be non-negative.')
91+
if nms_config.use_combined_nms and nms_config.use_class_agnostic_nms:
92+
raise ValueError('combined_nms does not support class_agnostic_nms')
93+
9194
non_max_suppressor_fn = functools.partial(
9295
post_processing.batch_multiclass_non_max_suppression,
9396
score_thresh=nms_config.score_threshold,
@@ -97,7 +100,8 @@ def _build_non_max_suppressor(nms_config):
97100
use_static_shapes=nms_config.use_static_shapes,
98101
use_class_agnostic_nms=nms_config.use_class_agnostic_nms,
99102
max_classes_per_detection=nms_config.max_classes_per_detection,
100-
soft_nms_sigma=nms_config.soft_nms_sigma)
103+
soft_nms_sigma=nms_config.soft_nms_sigma,
104+
use_combined_nms=nms_config.use_combined_nms)
101105
return non_max_suppressor_fn
102106

103107

research/object_detection/core/batch_multiclass_nms_test.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,55 @@ def test_batch_multiclass_nms_with_additional_fields_and_num_valid_boxes(
663663
exp_nms_additional_fields[key])
664664
self.assertAllClose(num_detections, [1, 1])
665665

666-
# TODO(bhattad): Remove conditional after CMLE moves to TF 1.9
666+
def test_combined_nms_with_batch_size_2(self):
667+
"""Test use_combined_nms."""
668+
boxes = tf.constant([[[[0, 0, 0.1, 0.1], [0, 0, 0.1, 0.1]],
669+
[[0, 0.01, 1, 0.11], [0, 0.6, 0.1, 0.7]],
670+
[[0, -0.01, 0.1, 0.09], [0, -0.1, 0.1, 0.09]],
671+
[[0, 0.11, 0.1, 0.2], [0, 0.11, 0.1, 0.2]]],
672+
[[[0, 0, 0.2, 0.2], [0, 0, 0.2, 0.2]],
673+
[[0, 0.02, 0.2, 0.22], [0, 0.02, 0.2, 0.22]],
674+
[[0, -0.02, 0.2, 0.19], [0, -0.02, 0.2, 0.19]],
675+
[[0, 0.21, 0.2, 0.3], [0, 0.21, 0.2, 0.3]]]],
676+
tf.float32)
677+
scores = tf.constant([[[.1, 0.9], [.75, 0.8],
678+
[.6, 0.3], [0.95, 0.1]],
679+
[[.1, 0.9], [.75, 0.8],
680+
[.6, .3], [.95, .1]]])
681+
score_thresh = 0.1
682+
iou_thresh = .5
683+
max_output_size = 3
684+
685+
exp_nms_corners = np.array([[[0, 0.11, 0.1, 0.2],
686+
[0, 0, 0.1, 0.1],
687+
[0, 0.6, 0.1, 0.7]],
688+
[[0, 0.21, 0.2, 0.3],
689+
[0, 0, 0.2, 0.2],
690+
[0, 0.02, 0.2, 0.22]]])
691+
exp_nms_scores = np.array([[.95, .9, 0.8],
692+
[.95, .9, .75]])
693+
exp_nms_classes = np.array([[0, 1, 1],
694+
[0, 1, 0]])
695+
696+
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
697+
nmsed_additional_fields, num_detections
698+
) = post_processing.batch_multiclass_non_max_suppression(
699+
boxes, scores, score_thresh, iou_thresh,
700+
max_size_per_class=max_output_size, max_total_size=max_output_size,
701+
use_static_shapes=True,
702+
use_combined_nms=True)
703+
704+
self.assertIsNone(nmsed_masks)
705+
self.assertIsNone(nmsed_additional_fields)
706+
707+
with self.test_session() as sess:
708+
(nmsed_boxes, nmsed_scores, nmsed_classes,
709+
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
710+
num_detections])
711+
self.assertAllClose(nmsed_boxes, exp_nms_corners)
712+
self.assertAllClose(nmsed_scores, exp_nms_scores)
713+
self.assertAllClose(nmsed_classes, exp_nms_classes)
714+
self.assertListEqual(num_detections.tolist(), [3, 3])
667715

668716
if __name__ == '__main__':
669717
tf.test.main()

research/object_detection/core/post_processing.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,8 @@ def batch_multiclass_non_max_suppression(boxes,
820820
use_static_shapes=False,
821821
parallel_iterations=32,
822822
use_class_agnostic_nms=False,
823-
max_classes_per_detection=1):
823+
max_classes_per_detection=1,
824+
use_combined_nms=False):
824825
"""Multi-class version of non maximum suppression that operates on a batch.
825826
826827
This op is similar to `multiclass_non_max_suppression` but operates on a batch
@@ -866,14 +867,28 @@ def batch_multiclass_non_max_suppression(boxes,
866867
False.
867868
scope: tf scope name.
868869
use_static_shapes: If true, the output nmsed boxes are padded to be of
869-
length `max_size_per_class` and it doesn't clip boxes to max_total_size.
870+
length `minimum(max_total_size, max_size_per_class*num_classes)`.
871+
If false, they are padded to be of length `max_total_size`.
870872
Defaults to false.
871873
parallel_iterations: (optional) number of batch items to process in
872874
parallel.
873875
use_class_agnostic_nms: If true, this uses class-agnostic non max
874876
suppression
875877
max_classes_per_detection: Maximum number of retained classes per detection
876878
box in class-agnostic NMS.
879+
use_combined_nms: If true, it uses tf.image.combined_non_max_suppression (
880+
multi-class version of NMS that operates on a batch).
881+
It greedily selects a subset of detection bounding boxes, pruning away
882+
boxes that have high IOU (intersection over union) overlap (> thresh) with
883+
already selected boxes. It operates independently for each batch.
884+
Within each batch, it operates independently for each class for which
885+
scores are provided (via the scores field of the input box_list),
886+
pruning boxes with score less than a provided threshold prior to applying
887+
NMS. This operation is performed on *all* batches and *all* classes
888+
in the batch, therefore any background classes should be removed prior to
889+
calling this function.
890+
Masks and additional fields are not supported.
891+
See argument checks in the code below for unsupported arguments.
877892
878893
Returns:
879894
'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
@@ -899,11 +914,57 @@ def batch_multiclass_non_max_suppression(boxes,
899914
ValueError: if `q` in boxes.shape is not 1 or not equal to number of
900915
classes as inferred from scores.shape.
901916
"""
917+
if use_combined_nms:
918+
if change_coordinate_frame:
919+
raise ValueError(
920+
'change_coordinate_frame (normalizing coordinates'
921+
' relative to clip_window) is not supported by combined_nms.')
922+
if num_valid_boxes is not None:
923+
raise ValueError('num_valid_boxes is not supported by combined_nms.')
924+
if masks is not None:
925+
raise ValueError('masks is not supported by combined_nms.')
926+
if soft_nms_sigma != 0.0:
927+
raise ValueError('Soft NMS is not supported by combined_nms.')
928+
if use_class_agnostic_nms:
929+
raise ValueError('class-agnostic NMS is not supported by combined_nms.')
930+
if clip_window is not None:
931+
tf.compat.v1.logging.warning(
932+
'clip_window is not supported by combined_nms unless it is'
933+
' [0. 0. 1. 1.] for each image.')
934+
if additional_fields is not None:
935+
tf.compat.v1.logging.warning(
936+
'additional_fields is not supported by combined_nms.')
937+
if parallel_iterations != 32:
938+
tf.compat.v1.logging.warning(
939+
'Number of batch items to be processed in parallel is'
940+
' not configurable by combined_nms.')
941+
if max_classes_per_detection > 1:
942+
tf.compat.v1.logging.warning(
943+
'max_classes_per_detection is not configurable by combined_nms.')
944+
945+
with tf.name_scope(scope, 'CombinedNonMaxSuppression'):
946+
(batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
947+
batch_num_detections) = tf.image.combined_non_max_suppression(
948+
boxes=boxes,
949+
scores=scores,
950+
max_output_size_per_class=max_size_per_class,
951+
max_total_size=max_total_size,
952+
iou_threshold=iou_thresh,
953+
score_threshold=score_thresh,
954+
pad_per_class=use_static_shapes)
955+
# Not supported by combined_non_max_suppression.
956+
batch_nmsed_masks = None
957+
# Not supported by combined_non_max_suppression.
958+
batch_nmsed_additional_fields = None
959+
return (batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
960+
batch_nmsed_masks, batch_nmsed_additional_fields,
961+
batch_num_detections)
962+
902963
q = shape_utils.get_dim_as_int(boxes.shape[2])
903964
num_classes = shape_utils.get_dim_as_int(scores.shape[2])
904965
if q != 1 and q != num_classes:
905966
raise ValueError('third dimension of boxes must be either 1 or equal '
906-
'to the third dimension of scores')
967+
'to the third dimension of scores.')
907968
if change_coordinate_frame and clip_window is None:
908969
raise ValueError('if change_coordinate_frame is True, then a clip_window'
909970
'must be specified.')

research/object_detection/meta_architectures/faster_rcnn_meta_arch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1621,7 +1621,8 @@ def normalize_boxes(args):
16211621
normalize_boxes,
16221622
elems=[raw_proposal_boxes, image_shapes],
16231623
dtype=tf.float32)
1624-
proposal_multiclass_scores = nmsed_additional_fields['multiclass_scores']
1624+
proposal_multiclass_scores = nmsed_additional_fields.get(
1625+
'multiclass_scores') if nmsed_additional_fields else None,
16251626
return (normalized_proposal_boxes, proposal_scores,
16261627
proposal_multiclass_scores, num_proposals,
16271628
raw_normalized_proposal_boxes, rpn_objectness_softmax)

research/object_detection/meta_architectures/ssd_meta_arch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,8 @@ def postprocess(self, prediction_dict, true_image_shapes):
746746
fields.DetectionResultFields.detection_classes:
747747
nmsed_classes,
748748
fields.DetectionResultFields.detection_multiclass_scores:
749-
nmsed_additional_fields['multiclass_scores'],
749+
nmsed_additional_fields.get(
750+
'multiclass_scores') if nmsed_additional_fields else None,
750751
fields.DetectionResultFields.num_detections:
751752
tf.cast(num_detections, dtype=tf.float32),
752753
fields.DetectionResultFields.raw_detection_boxes:

research/object_detection/protos/faster_rcnn.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ message FasterRcnn {
168168
// If True, uses implementation of ops with static shape guarantees when
169169
// running evaluation (specifically not is_training if False).
170170
optional bool use_static_shapes_for_eval = 37 [default = false];
171+
172+
// Whether to use tf.image.combined_non_max_suppression.
173+
optional bool use_combined_nms_in_first_stage = 38 [default=false];
171174
}
172175

173176

research/object_detection/protos/post_processing.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ message BatchNonMaxSuppression {
3939

4040
// Soft NMS sigma parameter; Bodla et al, https://arxiv.org/abs/1704.04503)
4141
optional float soft_nms_sigma = 9 [default = 0.0];
42+
43+
// Whether to use tf.image.combined_non_max_suppression.
44+
optional bool use_combined_nms = 10 [default = false];
4245
}
4346

4447
// Configuration proto for post-processing predicted boxes and

0 commit comments

Comments
 (0)