Skip to content

Commit f4315ee

Browse files
tensorflower-gardenerfyangf
authored andcommitted
Internal change
PiperOrigin-RevId: 505144061
1 parent 695299a commit f4315ee

File tree

2 files changed

+225
-21
lines changed

2 files changed

+225
-21
lines changed

official/vision/modeling/layers/detection_generator.py

Lines changed: 201 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,130 @@ def _select_top_k_scores(scores_in: tf.Tensor, pre_nms_num_detections: int):
328328
)
329329

330330

331-
def _generate_detections_v2(
331+
def _generate_detections_v2_class_agnostic(
    boxes: tf.Tensor,
    scores: tf.Tensor,
    pre_nms_top_k: int = 5000,
    pre_nms_score_threshold: float = 0.05,
    nms_iou_threshold: float = 0.5,
    max_num_detections: int = 100,
):
  """Generates the final detections by applying class-agnostic NMS.

  Every anchor is first reduced to its single highest-scoring class. The
  `pre_nms_top_k` best anchors (by that score) are then kept and a single NMS
  pass is run over them, ignoring the class of each box.

  Args:
    boxes: A `tf.Tensor` with shape `[batch_size, N, num_classes, 4]` or
      `[batch_size, N, 1, 4]`, which holds box predictions on all feature
      levels. The N is the number of total anchors on all levels.
    scores: A `tf.Tensor` with shape `[batch_size, N, num_classes]`, which
      stacks class probability on all feature levels. The N is the number of
      total anchors on all levels. The num_classes is the number of classes
      predicted by the model. Note that the class_outputs here is the raw
      score.
    pre_nms_top_k: An `int` number of top candidate detections kept before
      NMS. Because each anchor has already been reduced to its best class,
      this limit applies across all classes, not per class.
    pre_nms_score_threshold: A `float` representing the threshold for deciding
      when to remove boxes based on score.
    nms_iou_threshold: A `float` representing the threshold for deciding
      whether boxes overlap too much with respect to IOU.
    max_num_detections: A `scalar` representing maximum number of boxes
      retained over all classes.

  Returns:
    nms_boxes: A `float` tf.Tensor of shape [batch_size, max_num_detections, 4]
      representing top detected boxes in [y1, x1, y2, x2].
    nms_scores: A `float` tf.Tensor of shape [batch_size, max_num_detections]
      representing sorted confidence scores for detected boxes. The values are
      between [0, 1].
    nms_classes: An `int` tf.Tensor of shape [batch_size, max_num_detections]
      representing classes for detected boxes.
    valid_detections: An `int` tf.Tensor of shape [batch_size] only the top
      `valid_detections` boxes are valid detections.
    Entries past `valid_detections` in the first three outputs are zeroed.
  """
  with tf.name_scope('generate_detections_class_agnostic'):
    # Unpacking (rather than indexing) keeps the implicit rank check on both
    # inputs: a tensor of the wrong rank fails here instead of further down.
    _, _, num_classes_for_box, _ = boxes.get_shape().as_list()
    _, total_anchors, _ = scores.get_shape().as_list()

    # Keeps only the class with highest score for each predicted box.
    scores_condensed, classes_ids = tf.nn.top_k(scores, k=1, sorted=True)
    scores_condensed = tf.squeeze(scores_condensed, axis=[2])
    if num_classes_for_box > 1:
      # Per-class boxes: pick the box that belongs to the winning class.
      boxes = tf.gather(boxes, classes_ids, axis=2, batch_dims=2)
    boxes_condensed = tf.squeeze(boxes, axis=[2])
    classes_condensed = tf.squeeze(classes_ids, axis=[2])

    # Selects top pre_nms_num scores and indices before NMS.
    num_anchors_filtered = min(total_anchors, pre_nms_top_k)
    scores_filtered, indices_filtered = tf.nn.top_k(
        scores_condensed, k=num_anchors_filtered, sorted=True
    )
    classes_filtered = tf.gather(
        classes_condensed, indices_filtered, axis=1, batch_dims=1
    )
    boxes_filtered = tf.gather(
        boxes_condensed, indices_filtered, axis=1, batch_dims=1
    )

    # `tf.ensure_shape` returns the checked tensor; its result must be used.
    # Discarding it loses the refined static shape and leaves the runtime
    # check unattached to anything that is returned.
    boxes_filtered = tf.ensure_shape(
        boxes_filtered, [None, num_anchors_filtered, 4]
    )
    classes_filtered = tf.ensure_shape(
        classes_filtered, [None, num_anchors_filtered]
    )
    scores_filtered = tf.ensure_shape(
        scores_filtered, [None, num_anchors_filtered]
    )

    boxes_filtered = tf.cast(boxes_filtered, tf.float32)
    scores_filtered = tf.cast(scores_filtered, tf.float32)

    # Apply class-agnostic NMS on boxes. The candidates are already sorted by
    # descending score (top_k above), hence `sorted_input=True`.
    (nmsed_indices_padded, valid_detections) = (
        tf.image.non_max_suppression_padded(
            boxes=boxes_filtered,
            scores=scores_filtered,
            max_output_size=max_num_detections,
            iou_threshold=nms_iou_threshold,
            pad_to_max_output_size=True,
            score_threshold=pre_nms_score_threshold,
            sorted_input=True,
            name='nms_detections',
        )
    )
    nmsed_boxes = tf.gather(
        boxes_filtered, nmsed_indices_padded, batch_dims=1, axis=1
    )
    nmsed_scores = tf.gather(
        scores_filtered, nmsed_indices_padded, batch_dims=1, axis=1
    )
    nmsed_classes = tf.gather(
        classes_filtered, nmsed_indices_padded, batch_dims=1, axis=1
    )

    # Sets the padded boxes, scores, and classes to 0.
    padding_mask = tf.reshape(
        tf.range(max_num_detections), [1, -1]
    ) < tf.reshape(valid_detections, [-1, 1])
    nmsed_boxes = nmsed_boxes * tf.cast(
        tf.expand_dims(padding_mask, axis=2), nmsed_boxes.dtype
    )
    nmsed_scores = nmsed_scores * tf.cast(padding_mask, nmsed_scores.dtype)
    nmsed_classes = nmsed_classes * tf.cast(padding_mask, nmsed_classes.dtype)

  return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
444+
445+
446+
def _generate_detections_v2_class_aware(
447+
boxes: tf.Tensor,
448+
scores: tf.Tensor,
449+
pre_nms_top_k: int = 5000,
450+
pre_nms_score_threshold: float = 0.05,
451+
nms_iou_threshold: float = 0.5,
452+
max_num_detections: int = 100,
453+
):
454+
"""Generates the final detections by using class-aware NMS.
345455
346456
Args:
347457
boxes: A `tf.Tensor` with shape `[batch_size, N, num_classes, 4]` or
@@ -419,6 +529,72 @@ def _generate_detections_v2(
419529
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
420530

421531

532+
def _generate_detections_v2(
    boxes: tf.Tensor,
    scores: tf.Tensor,
    pre_nms_top_k: int = 5000,
    pre_nms_score_threshold: float = 0.05,
    nms_iou_threshold: float = 0.5,
    max_num_detections: int = 100,
    use_class_agnostic_nms: Optional[bool] = None,
):
  """Generates the final detections given the model outputs (NMS v2).

  This is a thin dispatcher: when `use_class_agnostic_nms` is truthy the work
  is delegated to `_generate_detections_v2_class_agnostic`, otherwise (False
  or None) to `_generate_detections_v2_class_aware`. All other arguments are
  forwarded unchanged.

  NOTE: the class-aware variant is the implementation previously documented
  here as unrolling the classes dimension inside a tf.while_loop for
  batch-level parallelism and as TPU compatible; confirm against that
  function.

  Args:
    boxes: A `tf.Tensor` with shape `[batch_size, N, num_classes, 4]` or
      `[batch_size, N, 1, 4]`, which holds box predictions on all feature
      levels. The N is the number of total anchors on all levels.
    scores: A `tf.Tensor` with shape `[batch_size, N, num_classes]`, which
      stacks class probability on all feature levels. The N is the number of
      total anchors on all levels. The num_classes is the number of classes
      predicted by the model. Note that the class_outputs here is the raw
      score.
    pre_nms_top_k: An `int` number of top candidate detections kept before
      NMS.
    pre_nms_score_threshold: A `float` representing the threshold for deciding
      when to remove boxes based on score.
    nms_iou_threshold: A `float` representing the threshold for deciding
      whether boxes overlap too much with respect to IOU.
    max_num_detections: A `scalar` representing maximum number of boxes
      retained over all classes.
    use_class_agnostic_nms: A `bool` of whether non max suppression is
      operated on all the boxes using max scores across all classes.

  Returns:
    nms_boxes: A `float` tf.Tensor of shape [batch_size, max_num_detections, 4]
      representing top detected boxes in [y1, x1, y2, x2].
    nms_scores: A `float` tf.Tensor of shape [batch_size, max_num_detections]
      representing sorted confidence scores for detected boxes. The values are
      between [0, 1].
    nms_classes: An `int` tf.Tensor of shape [batch_size, max_num_detections]
      representing classes for detected boxes.
    valid_detections: An `int` tf.Tensor of shape [batch_size] only the top
      `valid_detections` boxes are valid detections.
  """
  # Pick the implementation once; both variants take the same keyword set.
  nms_fn = (
      _generate_detections_v2_class_agnostic
      if use_class_agnostic_nms
      else _generate_detections_v2_class_aware
  )
  return nms_fn(
      boxes=boxes,
      scores=scores,
      pre_nms_top_k=pre_nms_top_k,
      pre_nms_score_threshold=pre_nms_score_threshold,
      nms_iou_threshold=nms_iou_threshold,
      max_num_detections=max_num_detections,
  )
596+
597+
422598
def _generate_detections_v3(
423599
boxes: tf.Tensor,
424600
scores: tf.Tensor,
@@ -957,6 +1133,7 @@ def __init__(
9571133
pre_nms_top_k_sharding_block: Optional[int] = None,
9581134
nms_v3_refinements: Optional[int] = None,
9591135
return_decoded: Optional[bool] = None,
1136+
use_class_agnostic_nms: Optional[bool] = None,
9601137
**kwargs,
9611138
):
9621139
"""Initializes a multi-level detection generator.
@@ -989,8 +1166,21 @@ def __init__(
9891166
if == 2, AP is reduced <0.1%, AR is reduced <1% on COCO
9901167
return_decoded: A `bool` of whether to return decoded boxes before NMS
9911168
regardless of whether `apply_nms` is True or not.
1169+
use_class_agnostic_nms: A `bool` of whether non max suppression is
1170+
operated on all the boxes using max scores across all classes.
9921171
**kwargs: Additional keyword arguments passed to Layer.
1172+
1173+
Raises:
1174+
ValueError: If `use_class_agnostic_nms` is required but `nms_version` is
1175+
not specified as `v2`.
9931176
"""
1177+
if use_class_agnostic_nms and nms_version != 'v2':
1178+
raise ValueError(
1179+
'If not using TFLite custom NMS, `use_class_agnostic_nms` can only be'
1180+
' enabled for NMS v2 for now, but NMS {} is used! If you are using'
1181+
' TFLite NMS, please configure TFLite custom NMS for class-agnostic'
1182+
' NMS.'.format(nms_version)
1183+
)
9941184
self._config_dict = {
9951185
'apply_nms': apply_nms,
9961186
'pre_nms_top_k': pre_nms_top_k,
@@ -1001,6 +1191,7 @@ def __init__(
10011191
'use_cpu_nms': use_cpu_nms,
10021192
'soft_nms_sigma': soft_nms_sigma,
10031193
'return_decoded': return_decoded,
1194+
'use_class_agnostic_nms': use_class_agnostic_nms,
10041195
}
10051196
# Don't store if were not defined
10061197
if pre_nms_top_k_sharding_block is not None:
@@ -1347,6 +1538,9 @@ def __call__(
13471538
],
13481539
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
13491540
max_num_detections=self._config_dict['max_num_detections'],
1541+
use_class_agnostic_nms=self._config_dict[
1542+
'use_class_agnostic_nms'
1543+
],
13501544
)
13511545
)
13521546
# Set `nmsed_attributes` to None for v2.

official/vision/modeling/layers/detection_generator_test.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -124,20 +124,28 @@ class MultilevelDetectionGeneratorTest(
124124
parameterized.TestCase, tf.test.TestCase):
125125

126126
@parameterized.parameters(
127-
('batched', False, True, None, None),
128-
('batched', False, False, None, None),
129-
('v3', False, True, None, None),
130-
('v3', False, False, None, None),
131-
('v2', False, True, None, None),
132-
('v2', False, False, None, None),
133-
('v1', True, True, 0.0, None),
134-
('v1', True, False, 0.1, None),
135-
('v1', True, False, None, None),
136-
('tflite', False, False, None, True),
137-
('tflite', False, False, None, False),
127+
('batched', False, True, None, None, None),
128+
('batched', False, False, None, None, None),
129+
('v3', False, True, None, None, None),
130+
('v3', False, False, None, None, None),
131+
('v2', False, True, None, None, None),
132+
('v2', False, False, None, None, None),
133+
('v2', False, False, None, None, True),
134+
('v1', True, True, 0.0, None, None),
135+
('v1', True, False, 0.1, None, None),
136+
('v1', True, False, None, None, None),
137+
('tflite', False, False, None, True, None),
138+
('tflite', False, False, None, False, None),
138139
)
139-
def testDetectionsOutputShape(self, nms_version, has_att_heads, use_cpu_nms,
140-
soft_nms_sigma, use_regular_nms):
140+
def testDetectionsOutputShape(
141+
self,
142+
nms_version,
143+
has_att_heads,
144+
use_cpu_nms,
145+
soft_nms_sigma,
146+
use_regular_nms,
147+
use_class_agnostic_nms,
148+
):
141149
min_level = 4
142150
max_level = 6
143151
num_scales = 2
@@ -167,7 +175,8 @@ def testDetectionsOutputShape(self, nms_version, has_att_heads, use_cpu_nms,
167175
'nms_version': nms_version,
168176
'use_cpu_nms': use_cpu_nms,
169177
'soft_nms_sigma': soft_nms_sigma,
170-
'tflite_post_processing_config': tflite_post_processing_config
178+
'tflite_post_processing_config': tflite_post_processing_config,
179+
'use_class_agnostic_nms': use_class_agnostic_nms,
171180
}
172181

173182
input_anchor = anchor.build_anchor_generator(min_level, max_level,
@@ -338,6 +347,7 @@ def test_serialize_deserialize(self):
338347
'soft_nms_sigma': None,
339348
'tflite_post_processing_config': tflite_post_processing_config,
340349
'return_decoded': False,
350+
'use_class_agnostic_nms': False,
341351
}
342352
generator = detection_generator.MultilevelDetectionGenerator(**kwargs)
343353

0 commit comments

Comments
 (0)