Skip to content

Commit f4f994e

Browse files
tensorflower-gardenerfyangf
authored and committed
Internal change
PiperOrigin-RevId: 495044886
1 parent 59ddcf2 commit f4f994e

File tree

2 files changed

+150
-14
lines changed

2 files changed

+150
-14
lines changed

official/vision/modeling/layers/detection_generator.py

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,7 @@ def __init__(self,
853853
use_cpu_nms: bool = False,
854854
soft_nms_sigma: Optional[float] = None,
855855
tflite_post_processing_config: Optional[Dict[str, Any]] = None,
856+
pre_nms_top_k_sharding_block: Optional[int] = None,
856857
**kwargs):
857858
"""Initializes a multi-level detection generator.
858859
@@ -873,6 +874,9 @@ def __init__(self,
873874
When soft_nms_sigma=0.0, we fall back to standard NMS.
874875
tflite_post_processing_config: An optional dictionary containing
875876
post-processing parameters used for TFLite custom NMS op.
877+
pre_nms_top_k_sharding_block: For v3 (Edge TPU friendly) NMS, avoids
878+
creating a long axis for pre_nms_top_k. Will do top_k in shards of size
879+
[num_classes, pre_nms_top_k_sharding_block * boxes_per_location].
876880
877881
**kwargs: Additional keyword arguments passed to Layer.
878882
"""
@@ -886,11 +890,15 @@ def __init__(self,
886890
'use_cpu_nms': use_cpu_nms,
887891
'soft_nms_sigma': soft_nms_sigma
888892
}
893+
# Don't store the key if it was not defined.
894+
if pre_nms_top_k_sharding_block is not None:
895+
self._config_dict[
896+
'pre_nms_top_k_sharding_block'] = pre_nms_top_k_sharding_block
889897

890898
if tflite_post_processing_config is not None:
891899
self._config_dict.update(
892900
{'tflite_post_processing_config': tflite_post_processing_config})
893-
super(MultilevelDetectionGenerator, self).__init__(**kwargs)
901+
super().__init__(**kwargs)
894902

895903
def _decode_multilevel_outputs(
896904
self,
@@ -969,12 +977,74 @@ def _decode_multilevel_outputs(
969977

970978
return boxes, scores, attributes
971979

972-
def __call__(self,
973-
raw_boxes: Mapping[str, tf.Tensor],
974-
raw_scores: Mapping[str, tf.Tensor],
975-
anchor_boxes: Mapping[str, tf.Tensor],
976-
image_shape: tf.Tensor,
977-
raw_attributes: Optional[Mapping[str, tf.Tensor]] = None):
980+
def _decode_multilevel_outputs_and_pre_nms_top_k(
981+
self, raw_boxes: Mapping[str, tf.Tensor],
982+
raw_scores: Mapping[str, tf.Tensor], anchor_boxes: Mapping[str,
983+
tf.Tensor],
984+
image_shape: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
985+
"""Decodes multilevel boxes and scores, keeping a running pre-NMS top-k."""
986+
boxes = None
987+
scores = None
988+
989+
pre_nms_top_k = self._config_dict['pre_nms_top_k']
990+
# TODO(b/258007436): consider removing once the compiler is able to handle
991+
# it on its own.
992+
pre_nms_top_k_sharding_block = self._config_dict.get(
993+
'pre_nms_top_k_sharding_block', 128)
994+
levels = list(raw_boxes.keys())
995+
min_level = int(min(levels))
996+
max_level = int(max(levels))
997+
for i in range(max_level, min_level - 1, -1):
998+
(_, unsharded_h, unsharded_w, num_anchors_per_locations_times_4
999+
) = raw_boxes[str(i)].get_shape().as_list()
1000+
block = max(1, pre_nms_top_k_sharding_block // unsharded_w)
1001+
anchor_boxes_unsharded = tf.reshape(
1002+
anchor_boxes[str(i)],
1003+
[1, unsharded_h, unsharded_w, num_anchors_per_locations_times_4])
1004+
for (raw_scores_i, raw_boxes_i, anchor_boxes_i) in edgetpu.shard_tensors(
1005+
1, block,
1006+
(raw_scores[str(i)], raw_boxes[str(i)], anchor_boxes_unsharded)):
1007+
batch_size = tf.shape(raw_boxes_i)[0]
1008+
(_, feature_h_i, feature_w_i, _) = raw_boxes_i.get_shape().as_list()
1009+
num_locations = feature_h_i * feature_w_i
1010+
num_anchors_per_locations = num_anchors_per_locations_times_4 // 4
1011+
num_classes = raw_scores_i.get_shape().as_list(
1012+
)[-1] // num_anchors_per_locations
1013+
1014+
# Removes the implicit background class; sigmoid is applied after top-k.
1015+
scores_i = tf.slice(
1016+
tf.transpose(
1017+
tf.reshape(raw_scores_i, [
1018+
batch_size, num_locations * num_anchors_per_locations,
1019+
num_classes
1020+
]), [0, 2, 1]), [0, 1, 0], [-1, -1, -1])
1021+
1022+
# Box decoding.
1023+
# The anchor boxes are shared for all data in a batch.
1024+
# One stage detector only supports class agnostic box regression.
1025+
boxes_shape = [batch_size, num_locations * num_anchors_per_locations, 4]
1026+
boxes_i = tf.tile(
1027+
tf.expand_dims(
1028+
box_ops.decode_boxes(
1029+
tf.reshape(raw_boxes_i, boxes_shape),
1030+
tf.reshape(anchor_boxes_i, boxes_shape)),
1031+
axis=1), [1, num_classes - 1, 1, 1])
1032+
scores, boxes = edgetpu.concat_and_top_k(pre_nms_top_k,
1033+
(scores, scores_i),
1034+
(boxes, boxes_i))
1035+
1036+
return (box_ops.clip_boxes(boxes,
1037+
tf.expand_dims(image_shape,
1038+
axis=1)), tf.sigmoid(scores))
1039+
1040+
def __call__(
1041+
self,
1042+
raw_boxes: Mapping[str, tf.Tensor],
1043+
raw_scores: Mapping[str, tf.Tensor],
1044+
anchor_boxes: Mapping[str, tf.Tensor],
1045+
image_shape: tf.Tensor,
1046+
raw_attributes: Optional[Mapping[str, tf.Tensor]] = None
1047+
) -> Mapping[str, Any]:
9781048
"""Generates final detections.
9791049
9801050
Args:
@@ -1031,8 +1101,13 @@ def __call__(self,
10311101
'detection_scores': scores
10321102
}
10331103

1034-
boxes, scores, attributes = self._decode_multilevel_outputs(
1035-
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
1104+
if self._config_dict['nms_version'] != 'v3':
1105+
boxes, scores, attributes = self._decode_multilevel_outputs(
1106+
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
1107+
else:
1108+
attributes = None
1109+
boxes, scores = self._decode_multilevel_outputs_and_pre_nms_top_k(
1110+
raw_boxes, raw_scores, anchor_boxes, image_shape)
10361111

10371112
if not self._config_dict['apply_nms']:
10381113
return {
@@ -1086,17 +1161,15 @@ def __call__(self,
10861161
# Set `nmsed_attributes` to an empty dict for v2.
10871162
nmsed_attributes = {}
10881163
elif self._config_dict['nms_version'] == 'v3':
1089-
# TODO(tohaspiridonov): add compatible version of
1090-
# `_decode_multilevel_outputs` in cl/485381750
10911164
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
10921165
_generate_detections_v3(
1093-
tf.transpose(boxes, [0, 2, 1, 3]),
1094-
tf.transpose(scores, [0, 2, 1]),
1166+
boxes,
1167+
scores,
10951168
pre_nms_score_threshold=self
10961169
._config_dict['pre_nms_score_threshold'],
10971170
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
10981171
max_num_detections=self._config_dict['max_num_detections']))
1099-
# Set `nmsed_attributes` to None for v2.
1172+
# Set `nmsed_attributes` to an empty dict for v3.
11001173
nmsed_attributes = {}
11011174
else:
11021175
raise ValueError('NMS version {} not supported.'.format(

official/vision/modeling/layers/detection_generator_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,69 @@ def testDetectionsOutputShape(self, nms_version, has_att_heads, use_cpu_nms,
255255
self.assertEqual(att.numpy().shape,
256256
(batch_size, max_num_detections, 1))
257257

258+
def test_decode_multilevel_outputs_and_pre_nms_top_k(self):
259+
named_params = {
260+
'apply_nms': True,
261+
'pre_nms_top_k': 5,
262+
'pre_nms_score_threshold': 0.05,
263+
'nms_iou_threshold': 0.5,
264+
'max_num_detections': 2,
265+
'nms_version': 'v3',
266+
'use_cpu_nms': False,
267+
'soft_nms_sigma': None,
268+
}
269+
generator = detection_generator.MultilevelDetectionGenerator(**named_params)
270+
# 2 classes (plus background), 3 boxes per pixel, 2 levels: '1' is 2x2, '2' is 1x1
271+
background = [1, 0, 0]
272+
first = [0, 1, 0]
273+
second = [0, 0, 1]
274+
some = [0, 0.5, 0.5]
275+
class_outputs = {
276+
'1':
277+
tf.constant([[[
278+
first + background + first, first + background + second
279+
], [second + background + first, second + background + second]]],
280+
dtype=tf.float32),
281+
'2':
282+
tf.constant([[[background + some + background]]], dtype=tf.float32),
283+
}
284+
box_outputs = {
285+
'1': tf.zeros(shape=[1, 2, 2, 12], dtype=tf.float32),
286+
'2': tf.zeros(shape=[1, 1, 1, 12], dtype=tf.float32)
287+
}
288+
anchor_boxes = {
289+
'1':
290+
tf.random.uniform(
291+
shape=[2, 2, 12], minval=1., maxval=99., dtype=tf.float32),
292+
'2':
293+
tf.random.uniform(
294+
shape=[1, 1, 12], minval=1., maxval=99., dtype=tf.float32),
295+
}
296+
boxes, scores = generator._decode_multilevel_outputs_and_pre_nms_top_k(
297+
box_outputs, class_outputs, anchor_boxes,
298+
tf.constant([[100, 100]], dtype=tf.float32))
299+
self.assertAllClose(
300+
scores,
301+
tf.sigmoid(
302+
tf.constant([[[1, 1, 1, 1, 0.5], [1, 1, 1, 1, 0.5]]],
303+
dtype=tf.float32)))
304+
self.assertAllClose(
305+
tf.squeeze(boxes),
306+
tf.stack([
307+
# Boxes where the `first` class scores highest, with the `some` box last.
308+
tf.stack([
309+
anchor_boxes['1'][0, 0, 0:4], anchor_boxes['1'][0, 0, 8:12],
310+
anchor_boxes['1'][0, 1, 0:4], anchor_boxes['1'][1, 0, 8:12],
311+
anchor_boxes['2'][0, 0, 4:8]
312+
]),
313+
# Boxes where the `second` class scores highest, with the `some` box last.
314+
tf.stack([
315+
anchor_boxes['1'][0, 1, 8:12], anchor_boxes['1'][1, 0, 0:4],
316+
anchor_boxes['1'][1, 1, 0:4], anchor_boxes['1'][1, 1, 8:12],
317+
anchor_boxes['2'][0, 0, 4:8]
318+
]),
319+
]))
320+
258321
def test_serialize_deserialize(self):
259322
tflite_post_processing_config = {
260323
'max_detections': 100,

0 commit comments

Comments
 (0)