Skip to content

Commit d1272f9

Browse files
tensorflower-gardenerfyangf
authored andcommitted
Internal change
PiperOrigin-RevId: 500453807
1 parent f4f994e commit d1272f9

File tree

3 files changed

+92
-34
lines changed

3 files changed

+92
-34
lines changed

official/vision/modeling/layers/detection_generator.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
# limitations under the License.
1414

1515
"""Contains definitions of generators to generate the final detections."""
16+
from collections.abc import Mapping, Sequence
1617
import contextlib
17-
from typing import Any, Dict, List, Optional, Mapping, Sequence, Tuple
18+
from typing import Any, Dict, List, Optional, Tuple
19+
1820
# Import libraries
21+
22+
import numpy as np
1923
import tensorflow as tf
2024

2125
from official.vision.modeling.layers import edgetpu
@@ -411,6 +415,7 @@ def _generate_detections_v3(
411415
Raises:
412416
ValueError if inputs shapes are not valid.
413417
"""
418+
one = tf.constant(1, dtype=scores.dtype)
414419
with tf.name_scope('generate_detections'):
415420
batch_size, num_box_classes, box_locations, sides = (
416421
boxes.get_shape().as_list())
@@ -436,14 +441,14 @@ def _generate_detections_v3(
436441
# Gather NMS-ed boxes and scores.
437442
safe_indices = tf.nn.relu(indices) # 0 for invalid
438443
invalid_detections = safe_indices - indices # 1 for invalid, 0 for valid
439-
valid_detections = 1.0 - invalid_detections # 0 for invalid, 1 for valid
444+
valid_detections = one - invalid_detections # 0 for invalid, 1 for valid
440445
safe_indices = tf.cast(safe_indices, tf.int32)
441-
boxes = tf.expand_dims(valid_detections, -1) * tf.gather(
442-
boxes, safe_indices, axis=2, batch_dims=2)
446+
boxes = tf.gather(boxes, safe_indices, axis=2, batch_dims=2)
447+
boxes = tf.cast(tf.expand_dims(valid_detections, -1), boxes.dtype) * boxes
443448
scores = valid_detections * tf.gather(
444449
scores, safe_indices, axis=2, batch_dims=2)
445450
# Complement with class numbers.
446-
classes = tf.range(num_classes, dtype=tf.float32)
451+
classes = tf.constant(np.arange(num_classes), dtype=scores.dtype)
447452
classes = tf.reshape(classes, [1, num_classes, 1])
448453
classes = tf.tile(classes, [batch_size, 1, max_num_detections])
449454
# Flatten classes, locations. Class = -1 for invalid detection
@@ -456,7 +461,7 @@ def _generate_detections_v3(
456461
boxes = tf.gather(boxes, indices, batch_dims=1, axis=1)
457462
classes = tf.gather(classes, indices, batch_dims=1, axis=1)
458463
invalid_detections = tf.nn.relu(classes) - classes
459-
valid_detections = tf.reduce_sum(1. - invalid_detections, axis=1)
464+
valid_detections = tf.reduce_sum(one - invalid_detections, axis=1)
460465
return boxes, scores, classes, valid_detections
461466

462467

@@ -995,16 +1000,18 @@ def _decode_multilevel_outputs_and_pre_nms_top_k(
9951000
min_level = int(min(levels))
9961001
max_level = int(max(levels))
9971002
for i in range(max_level, min_level - 1, -1):
998-
(_, unsharded_h, unsharded_w, num_anchors_per_locations_times_4
1003+
(batch_size, unsharded_h, unsharded_w, num_anchors_per_locations_times_4
9991004
) = raw_boxes[str(i)].get_shape().as_list()
1005+
if batch_size is None:
1006+
batch_size = tf.shape(raw_boxes[str(i)])[0]
10001007
block = max(1, pre_nms_top_k_sharding_block // unsharded_w)
1001-
anchor_boxes_unsharded = tf.reshape(
1002-
anchor_boxes[str(i)],
1003-
[1, unsharded_h, unsharded_w, num_anchors_per_locations_times_4])
1008+
anchor_boxes_unsharded = tf.reshape(anchor_boxes[str(i)], [
1009+
batch_size, unsharded_h, unsharded_w,
1010+
num_anchors_per_locations_times_4
1011+
])
10041012
for (raw_scores_i, raw_boxes_i, anchor_boxes_i) in edgetpu.shard_tensors(
10051013
1, block,
10061014
(raw_scores[str(i)], raw_boxes[str(i)], anchor_boxes_unsharded)):
1007-
batch_size = tf.shape(raw_boxes_i)[0]
10081015
(_, feature_h_i, feature_w_i, _) = raw_boxes_i.get_shape().as_list()
10091016
num_locations = feature_h_i * feature_w_i
10101017
num_anchors_per_locations = num_anchors_per_locations_times_4 // 4
@@ -1032,10 +1039,8 @@ def _decode_multilevel_outputs_and_pre_nms_top_k(
10321039
scores, boxes = edgetpu.concat_and_top_k(pre_nms_top_k,
10331040
(scores, scores_i),
10341041
(boxes, boxes_i))
1035-
1036-
return (box_ops.clip_boxes(boxes,
1037-
tf.expand_dims(image_shape,
1038-
axis=1)), tf.sigmoid(scores))
1042+
clip_shape = tf.expand_dims(tf.expand_dims(image_shape, axis=1), axis=1)
1043+
return box_ops.clip_boxes(boxes, clip_shape), tf.sigmoid(scores)
10391044

10401045
def __call__(
10411046
self,

official/vision/modeling/layers/edgetpu.py

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
# limitations under the License.
1414

1515
"""EdgeTPU oriented layers and tools."""
16-
1716
from collections.abc import Iterable, Sequence
18-
from typing import Optional
17+
from typing import List, Optional, Union
1918

2019
import numpy as np
2120
import tensorflow as tf
@@ -51,7 +50,8 @@ def _tensor_product_iou(boxes):
5150
# - last dimension is not 1. (Structure alignment)
5251
tpu_friendly_shape = [1, -1, 1, boxes_size]
5352
bottom, left, top, right = (
54-
tf.reshape(side, tpu_friendly_shape) for side in tf.split(boxes, 4, -1))
53+
tf.reshape(side, tpu_friendly_shape)
54+
for side in tf.split(boxes, 4, -1))
5555
height, width = top - bottom, right - left
5656
area = height * width
5757
area_sum = _tensor_sum_vectors(area, area)
@@ -116,6 +116,8 @@ def shard_tensors(axis: int, block_size: int,
116116
Raises:
117117
ValueError: if input tensors have different sizes of the sharded dimension.
118118
"""
119+
if not all(tensor.shape.is_fully_defined() for tensor in tensors):
120+
return [tensors]
119121
for validate_axis in range(axis + 1):
120122
consistent_length: int = tensors[0].shape[validate_axis]
121123
for tensor in tensors:
@@ -195,6 +197,8 @@ def non_max_suppression_padded(boxes: tf.Tensor,
195197
A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]` representing
196198
the selected indices from the boxes tensor and `-1` values for the padding.
197199
"""
200+
if not boxes.shape.is_fully_defined():
201+
return _non_max_suppression_as_is(boxes, scores, output_size, iou_threshold)
198202
# Does partitioning job to help compiler converge with memory.
199203
batch_shape = boxes.shape[:-2]
200204
batch_size = np.prod(batch_shape, dtype=np.int32)
@@ -211,6 +215,52 @@ def non_max_suppression_padded(boxes: tf.Tensor,
211215
return tf.reshape(indices, batch_shape + [output_size])
212216

213217

218+
def _refine_nms_graph_to_original_algorithm(better: tf.Tensor) -> tf.Tensor:
219+
"""Refines the relationship graph, bringing it closer to the iterative NMS.
220+
221+
See `test_refinement_sample` unit tests for example, also comments in body of
222+
the algorithm, for the intuition.
223+
224+
Args:
225+
better: is a tensor with zeros and ones so that [batch dims ..., box_1,
226+
box_2] represents the [adjacency
227+
matrix](https://en.wikipedia.org/wiki/Adjacency_matrix) for the
228+
[relation](https://en.wikipedia.org/wiki/Relation_(mathematics)) `better`
229+
between boxes box_1 and box_2.
230+
231+
Returns:
232+
Modification of tensor encoding adjacency matrix of `better` relation.
233+
"""
234+
# good_box: is a tensor with zeros and ones so that
235+
# [batch dims ..., box_i] represents belonging of a box_i to the `good`
236+
# subset. `good` subset is defined as exactly those boxes that do not have any
237+
# `better` boxes.
238+
# INTUITION: In terms of an oriented graph, this is the subset of nodes nobody
239+
# points to as "I'm better than you". These nodes will never be suppressed in
240+
# the original NMS algorithm.
241+
good_box = tf.constant(1.) - _reduce_or(better, axis=-1)
242+
# good_better: is a tensor with zeros and ones so that
243+
# [batch dims ..., box_1, box_2] represents the adjacency matrix for the
244+
# `good_better` relation on all boxes set. `good_better` relation is defined
245+
# as relation between good box and boxes it is better than.
246+
# INTUITION: In terms of oriented graph, this is subset of edges, which
247+
# doesn't have any other inbound edges. These edges will represent
248+
# suppression actions in the original NMS algorithm.
249+
good_better = _and(tf.expand_dims(good_box, axis=-2), better)
250+
# not_bad_box: is a tensor with zeros and ones so that
251+
# [batch dims ..., box_i] represents belonging of a box_i to the `not_bad`
252+
# subset. `not_bad` subset is defined as all those, and only those, boxes that
253+
# do not have any `good_better` boxes.
254+
# INTUITION: These nodes are nodes which are not suppressed by `good` boxes
255+
# in the original NMS algorithm.
256+
not_bad_box = tf.constant(1.) - _reduce_or(good_better, axis=-1)
257+
# return: is a tensor with zeros and ones so that
258+
# [batch dims ..., box_1, box_2] represents the adjacency matrix for the
259+
# `better` relation on all boxes set which is closer to represent suppression
260+
# procedure in original NMS algorithm.
261+
return _and(tf.expand_dims(not_bad_box, axis=-2), better)
262+
263+
214264
def _non_max_suppression_as_is(boxes: tf.Tensor,
215265
scores: tf.Tensor,
216266
output_size: int,
@@ -230,32 +280,34 @@ def _non_max_suppression_as_is(boxes: tf.Tensor,
230280
A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]` representing
231281
the selected indices from the boxes tensor and `-1` values for the padding.
232282
"""
233-
batch_shape = boxes.shape[:-2]
234-
batch_size = np.prod(batch_shape, dtype=np.int32)
235283
boxes_size = boxes.shape[-2]
236284
if boxes.shape[-1] != 4:
237285
raise ValueError(f'Boxes shape ({boxes.shape}) last dimension must be 4 '
238286
'to represent [y1, x1, y2, x2] boxes coordinates')
239287
if scores.shape != boxes.shape[:-1]:
240288
raise ValueError(f'Boxes shape ({boxes.shape}) and scores shape '
241289
f'({scores.shape}) do not match.')
242-
order = tf.range(boxes_size, dtype=tf.float32)
290+
order = tf.constant(np.arange(boxes_size), dtype=scores.dtype)
243291
relative_order = _tensor_sum_vectors(order, -order)
244292
relative_scores = _tensor_sum_vectors(scores, -scores)
245-
similar = _greater(_tensor_product_iou(boxes) - iou_threshold)
293+
similar = tf.cast(
294+
_greater(
295+
_tensor_product_iou(boxes) -
296+
tf.constant(iou_threshold, dtype=boxes.dtype)), scores.dtype)
246297
worse = _greater(relative_scores)
247298
same_later = _and(_same(relative_scores), _greater(relative_order))
248299
similar_worse_or_same_later = _and(similar, _or(worse, same_later))
249300
prunable = _reduce_or(similar_worse_or_same_later, axis=-1)
250-
remaining = tf.constant(1.) - prunable
251-
scores = tf.reshape(tf.exp(scores), [1, 1, batch_size, boxes_size])
252-
remaining = tf.reshape(remaining, [1, 1, batch_size, boxes_size])
301+
remaining = tf.constant(1, dtype=prunable.dtype) - prunable
302+
if scores.shape[0] is None:
303+
# Keep as much of the tensor shape defined as possible, so that error messages are clear.
304+
remaining = tf.reshape(remaining, [tf.shape(scores)[0], *scores.shape[1:]])
305+
else:
306+
remaining = tf.reshape(remaining, scores.shape)
253307
# top_k runs on TPU cores, let it happen, TPU tiles implementation is slower.
254308
top_k = tf.math.top_k(scores * remaining, output_size)
255-
indices = (
256-
tf.cast(top_k.indices, top_k.values.dtype) * _greater(top_k.values) -
257-
_same(top_k.values))
258-
return tf.reshape(indices, batch_shape + [output_size])
309+
return (tf.cast(top_k.indices, top_k.values.dtype) * _greater(top_k.values) -
310+
_same(top_k.values))
259311

260312

261313
def concat_and_top_k(

official/vision/serving/detection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@ class DetectionModule(export_base.ExportModule):
3232

3333
def _build_model(self):
3434

35-
if self._batch_size is None:
36-
# Only batched NMS is supported with dynamic batch size.
35+
nms_versions_supporting_dynamic_batch_size = {'batched', 'v3'}
36+
nms_version = self.params.task.model.detection_generator.nms_version
37+
if (self._batch_size is None and
38+
nms_version not in nms_versions_supporting_dynamic_batch_size):
39+
logging.info('nms_version is set to `batched` because `%s` '
40+
'does not support with dynamic batch size.', nms_version)
3741
self.params.task.model.detection_generator.nms_version = 'batched'
38-
logging.info(
39-
'nms_version is set to `batched` because only batched NMS is '
40-
'supported with dynamic batch size.')
4142

4243
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
4344
self._input_image_size + [3])

0 commit comments

Comments
 (0)