Skip to content

Commit 0516807

Browse files
tensorflower-gardenerfyangf
authored and committed
Internal change
PiperOrigin-RevId: 491378667
1 parent 95f1070 commit 0516807

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

official/vision/modeling/layers/detection_generator.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414

1515
"""Contains definitions of generators to generate the final detections."""
1616
import contextlib
17-
from typing import Any, Dict, List, Optional, Mapping, Sequence
17+
from typing import Any, Dict, List, Optional, Mapping, Sequence, Tuple
1818
# Import libraries
1919
import tensorflow as tf
2020

21+
from official.projects.edgetpu.vision.modeling import custom_layers
2122
from official.vision.ops import box_ops
2223
from official.vision.ops import nms
2324
from official.vision.ops import preprocess_ops
@@ -372,6 +373,93 @@ def _generate_detections_v2(boxes: tf.Tensor,
372373
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
373374

374375

376+
def _generate_detections_v3(
    boxes: tf.Tensor,
    scores: tf.Tensor,
    pre_nms_score_threshold: float = 0.05,
    nms_iou_threshold: float = 0.5,
    max_num_detections: int = 100
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
  """Runs EdgeTPU-oriented class-wise NMS and keeps the top detections.

  The pipeline is: drop candidates below `pre_nms_score_threshold`, run padded
  NMS for every class, then pick the `max_num_detections` highest-scoring
  results across all classes.

  Args:
    boxes: A `tf.Tensor` of shape `[batch_size, num_classes, N, 4]` or
      `[batch_size, 1, N, 4]` with the box predictions from all feature levels.
      N is the total number of anchors over all levels.
    scores: A `tf.Tensor` of shape `[batch_size, num_classes, N]` with the
      stacked per-class scores from all feature levels, where `num_classes` is
      the number of classes predicted by the model.
    pre_nms_score_threshold: A `float` score threshold applied before NMS.
    nms_iou_threshold: A `float` IoU threshold above which two boxes are
      considered to overlap too much.
    max_num_detections: An `int` giving the number of boxes retained per class
      by NMS and, afterwards, over all classes.

  Returns:
    nms_boxes: A `float` tf.Tensor of shape [batch_size, max_num_detections, 4]
      with the selected boxes in [y1, x1, y2, x2]. Boxes of invalid entries are
      multiplied by zero.
    nms_scores: A `float` tf.Tensor of shape [batch_size, max_num_detections]
      with the selected scores in descending order.
    nms_classes: A `float` tf.Tensor of shape [batch_size, max_num_detections]
      with the class index of every selected box, or -1 for an invalid entry.
    valid_detections: A `float` tf.Tensor of shape [batch_size] counting the
      selected entries whose class is not -1.

  Raises:
    ValueError: If the static shapes of `boxes` and `scores` are inconsistent.
  """
  with tf.name_scope('generate_detections'):
    batch_size, num_box_classes, box_locations, sides = boxes.shape.as_list()
    if batch_size is None:
      # The static batch size is unknown; use the dynamic one instead.
      batch_size = tf.shape(boxes)[0]
    _, num_classes, locations = scores.shape.as_list()

    # Static shape validation.
    if num_box_classes not in (1, num_classes):
      raise ValueError('Boxes should have either 1 class or same as scores.')
    if box_locations != locations:
      raise ValueError('Number of locations is different.')
    if sides != 4:
      raise ValueError('Number of sides is incorrect.')

    # Apply the score threshold ahead of NMS.
    boxes, scores = box_ops.filter_boxes_by_scores(
        boxes, scores, min_score_threshold=pre_nms_score_threshold)

    # Class-wise NMS; the helper is expected to mark padded slots with -1.
    nms_indices = custom_layers.non_max_suppression_padded(
        boxes, scores, max_num_detections, iou_threshold=nms_iou_threshold)

    # Build 0/1 masks arithmetically from the index tensor: clamping -1 to 0
    # makes the index safe to gather with, and the difference to the original
    # index is 1 exactly where the slot was invalid.
    clamped_indices = tf.nn.relu(nms_indices)
    invalid_mask = clamped_indices - nms_indices
    valid_mask = 1.0 - invalid_mask
    gather_indices = tf.cast(clamped_indices, tf.int32)
    boxes = tf.expand_dims(valid_mask, -1) * tf.gather(
        boxes, gather_indices, axis=2, batch_dims=2)
    scores = valid_mask * tf.gather(
        scores, gather_indices, axis=2, batch_dims=2)

    # Attach a class index to every NMS slot.
    class_ids = tf.reshape(
        tf.range(num_classes, dtype=tf.float32), [1, num_classes, 1])
    class_ids = tf.tile(class_ids, [batch_size, 1, max_num_detections])

    # Merge the class and detection axes; invalid slots get class -1.
    flat_size = num_classes * max_num_detections
    scores = tf.reshape(scores, [batch_size, flat_size])
    boxes = tf.reshape(boxes, [batch_size, flat_size, 4])
    class_ids = tf.reshape(
        valid_mask * class_ids - invalid_mask, [batch_size, flat_size])

    # Keep the best-scoring entries over all classes.
    scores, top_indices = tf.nn.top_k(
        scores, k=max_num_detections, sorted=True)
    boxes = tf.gather(boxes, top_indices, batch_dims=1, axis=1)
    class_ids = tf.gather(class_ids, top_indices, batch_dims=1, axis=1)

    # An entry is invalid iff its class is -1; count the remaining ones.
    still_invalid = tf.nn.relu(class_ids) - class_ids
    num_valid = tf.reduce_sum(1. - still_invalid, axis=1)
    return boxes, scores, class_ids, num_valid
461+
462+
375463
def _generate_detections_batched(boxes: tf.Tensor, scores: tf.Tensor,
376464
pre_nms_score_threshold: float,
377465
nms_iou_threshold: float,
@@ -997,6 +1085,19 @@ def __call__(self,
9971085
max_num_detections=self._config_dict['max_num_detections']))
9981086
# Set `nmsed_attributes` to an empty dict for v2.
9991087
nmsed_attributes = {}
1088+
elif self._config_dict['nms_version'] == 'v3':
1089+
# TODO(tohaspiridonov): add compatible version of
1090+
# `_decode_multilevel_outputs` in cl/485381750
1091+
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
1092+
_generate_detections_v3(
1093+
tf.transpose(boxes, [0, 2, 1, 3]),
1094+
tf.transpose(scores, [0, 2, 1]),
1095+
pre_nms_score_threshold=self
1096+
._config_dict['pre_nms_score_threshold'],
1097+
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
1098+
max_num_detections=self._config_dict['max_num_detections']))
1099+
# Set `nmsed_attributes` to an empty dict for v3.
1100+
nmsed_attributes = {}
10001101
else:
10011102
raise ValueError('NMS version {} not supported.'.format(
10021103
self._config_dict['nms_version']))

official/vision/modeling/layers/detection_generator_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ class MultilevelDetectionGeneratorTest(
126126
@parameterized.parameters(
127127
('batched', False, True, None, None),
128128
('batched', False, False, None, None),
129+
('v3', False, True, None, None),
130+
('v3', False, False, None, None),
129131
('v2', False, True, None, None),
130132
('v2', False, False, None, None),
131133
('v1', True, True, 0.0, None),

0 commit comments

Comments
 (0)