# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""EdgeTPU oriented layers and tools."""

from collections.abc import Iterable, Sequence
from typing import Optional

import numpy as np
import tensorflow as tf

# Boolean algebra over {0, 1}-valued float tensors, expressed with ops that
# lower well on EdgeTPU: max is OR, min is AND, reduce_max is ANY.
_or = tf.maximum
_and = tf.minimum
_reduce_or = tf.reduce_max


def _tensor_sum_vectors(a, b):
  """Sums values of `a` and `b` pairwise into a 4-D grid of combinations."""
  a = tf.tile(tf.reshape(a, [1, -1, 1, a.shape[-1]]), [1, 1, a.shape[-1], 1])
  b = tf.tile(tf.reshape(b, [1, -1, a.shape[-1], 1]), [1, 1, 1, a.shape[-1]])
  return a + b

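# A minimal sketch of `_tensor_sum_vectors` semantics (toy values are an
# assumption, not from the original file): entry [0, 0, j, k] of the result
# is a[k] + b[j].
#   a = tf.constant([1., 2.])
#   b = tf.constant([10., 20.])
#   _tensor_sum_vectors(a, b)  # -> [[[[11., 12.], [21., 22.]]]]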

def _tensor_product_iou(boxes):
  """Computes pairwise IOU.

  4-D tensors are used throughout to follow the TPU compiler's preferences.

  Args:
    boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.

  Returns:
    A 4-D float `Tensor` of shape `[1, 1, num_boxes, num_boxes]` containing
    pairwise IOU.
  """
  boxes_size = boxes.shape[-2]
  # The code below broadcasts operands frequently. Empirically, the TPU
  # compiler has fewer issues broadcasting if:
  # - the batch (first) dimension is 1 (special consideration for sharding);
  # - there are 4 dimensions (standard traversal mapping);
  # - the last dimension is not 1 (structure alignment).
  tpu_friendly_shape = [1, -1, 1, boxes_size]
  bottom, left, top, right = (
      tf.reshape(side, tpu_friendly_shape) for side in tf.split(boxes, 4, -1))
  height, width = top - bottom, right - left
  area = height * width
  area_sum = _tensor_sum_vectors(area, area)
  bottom_pad, left_pad, top_pad, right_pad = (
      tf.nn.relu(_tensor_sum_vectors(x, -x))
      for x in (-bottom, -left, top, right))
  height_pad, width_pad = bottom_pad + top_pad, left_pad + right_pad
  intersection = tf.nn.relu(height - height_pad) * tf.nn.relu(width - width_pad)
  union = area_sum - intersection
  # `_same(union)` adds 1 where union == 0 to avoid division by zero.
  iou = tf.math.divide(intersection, union + _same(union))
  return iou

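# A sketch of `_tensor_product_iou` on toy values (the values are an
# assumption): identical unit boxes give IOU 1, disjoint boxes give IOU 0,
# and the result has shape [1, 1, num_boxes, num_boxes].
#   boxes = tf.constant([[0., 0., 1., 1.],
#                        [0., 0., 1., 1.],
#                        [2., 2., 3., 3.]])
#   _tensor_product_iou(boxes)[0, 0]  # -> [[1., 1., 0.],
#                                     #     [1., 1., 0.],
#                                     #     [0., 0., 1.]]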

def _greater(x):
  """Avoids non-lowerable layers in boolean comparison.

  Logical operations result in tensors of boolean type. However, in serving
  such tensors cannot be cast to values because of NNAPI specs. The `tf.where`
  operation results in a `select` instruction lowering, which does not run
  well on all generations of EdgeTPUs.

  Args:
    x: any numeric tensor.

  Returns:
    tf.where(x > tf.zeros_like(x), tf.ones_like(x), tf.zeros_like(x))
  """
  x_clip = tf.minimum(tf.nn.relu(x), tf.constant(1, dtype=x.dtype))
  # -floor(-v) == ceil(v); ceil(clip(x, 0, 1)) is 1 for x > 0, else 0.
  return -tf.math.floor(-x_clip)

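# Sketch on toy values (an illustrative assumption): `_greater` is an
# indicator of x > 0 built only from well-lowering ops.
#   _greater(tf.constant([-1., 0., 2.]))  # -> [0., 0., 1.]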

def _same(x):
  """Avoids non-lowerable layers in boolean equality.

  Logical operations result in tensors of boolean type. However, in serving
  such tensors cannot be cast to values because of NNAPI specs. The `tf.where`
  operation results in a `select` instruction lowering, which does not run
  well on all generations of EdgeTPUs.

  Args:
    x: any numeric tensor.

  Returns:
    tf.where(x == tf.zeros_like(x), tf.ones_like(x), tf.zeros_like(x))
  """
  x_clip = tf.minimum(tf.abs(x), tf.constant(1, dtype=x.dtype))
  # 1 + floor(-clip(|x|, 0, 1)) is 1 for x == 0, else 0.
  return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)

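# Sketch on toy values (an illustrative assumption): `_same` is an indicator
# of x == 0 built only from well-lowering ops.
#   _same(tf.constant([-1., 0., 2.]))  # -> [0., 1., 0.]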

def shard_tensors(axis: int, block_size: int,
                  *tensors: tf.Tensor) -> Iterable[Sequence[tf.Tensor]]:
  """Consistently splits multiple tensors sharding-style.

  Args:
    axis: axis to be used to split tensors.
    block_size: block size to split tensors.
    *tensors: list of tensors.

  Returns:
    List of shards; each shard has exactly one piece of each input tensor.

  Raises:
    ValueError: if input tensors disagree in size along any dimension up to
      and including `axis`.
  """
  for validate_axis in range(axis + 1):
    consistent_length: int = tensors[0].shape[validate_axis]
    for tensor in tensors:
      if tensor.shape[validate_axis] != consistent_length:
        raise ValueError('Inconsistent shapes in shard_tensors: first is '
                         f'{tensors[0].shape} and another is {tensor.shape}')
  batch_size: int = tensors[0].shape[axis]
  if block_size >= batch_size:
    return [tensors]
  else:
    blocks = batch_size // block_size
    remainder = batch_size % block_size
    if remainder:
      tensor_parts = []
      for tensor in tensors:
        shape: tf.TensorShape = tensor.shape
        # Even part: the longest prefix that divides into whole blocks.
        body: tf.Tensor = tf.slice(tensor, [0] * len(shape), [
            size if i != axis else blocks * block_size
            for i, size in enumerate(shape)
        ])
        # Uneven remainder, kept as a final, smaller shard.
        tail: tf.Tensor = tf.slice(tensor, [
            0 if i != axis else (blocks * block_size)
            for i, _ in enumerate(shape)
        ], [
            size if i != axis else (size - blocks * block_size)
            for i, size in enumerate(shape)
        ])
        tensor_parts.append(tf.split(body, blocks, axis) + [tail])
      return zip(*tensor_parts)
    else:
      return zip(*[tf.split(tensor, blocks, axis) for tensor in tensors])

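# Usage sketch (names and shapes are assumptions): two tensors sharing a
# leading dimension of 5, split into blocks of 2 along axis 0, yield shards
# of sizes 2, 2 and 1.
#   boxes = tf.zeros([5, 4])
#   scores = tf.zeros([5])
#   for boxes_i, scores_i in shard_tensors(0, 2, boxes, scores):
#     print(boxes_i.shape[0], scores_i.shape[0])  # 2 2, then 2 2, then 1 1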

# TODO(b/258007436): Number is based on existing compiler limitations while
# running bf16 NMS on EdgeTPU. Remove manual sharding once the compiler issue
# is fixed.
_RECOMMENDED_NMS_MEMORY = 360000


def non_max_suppression_padded(boxes: tf.Tensor,
                               scores: tf.Tensor,
                               output_size: int,
                               iou_threshold: float = 0.5) -> tf.Tensor:
  """Selects a subset of boxes with the highest score among IOU-similar boxes.

  Prunes away boxes that have a high intersection-over-union (IOU) overlap
  with boxes having a higher score. Boxes are supplied as `[y1, x1, y2, x2]`,
  where `(y1, x1)` and `(y2, x2)` are the coordinates of any diagonal pair of
  box corners. Note that this algorithm is agnostic to the coordinate system:
  translations or reflections of the coordinate system result in the same
  boxes being selected by the algorithm. The output of this operation is a
  set of integers indexing into the input collection of bounding boxes
  representing the selected boxes.

  The returned set is padded on the right with `-1` values. The bounding box
  coordinates corresponding to the selected indices can then be obtained
  using the `tf.gather` operation. For example:
  ```python
  selected_indices = vision.modeling.layers.non_max_suppression_padded(
      boxes, scores, max_output_size, iou_threshold)
  selected_boxes = tf.gather(boxes, selected_indices)
  ```

  See the following documentation for implementation details:
  third_party/tensorflow_models/official/projects/edgetpu/vision/modeling/g3doc/non_max_suppression.md

  Args:
    boxes: A 2-D+ float `Tensor` of shape `[...batch_dims, num_boxes, 4]`.
    scores: A 1-D+ float `Tensor` of shape `[...batch_dims, num_boxes]`
      representing a single score corresponding to each box (each row of
      boxes).
    output_size: A scalar integer `Tensor` representing the maximum number of
      boxes to be selected by non-max suppression.
    iou_threshold: A 0-D float tensor representing the threshold for deciding
      whether boxes overlap too much with respect to IOU.

  Returns:
    A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]`
    representing the selected indices from the boxes tensor, with `-1` values
    for the padding.
  """
  # Shards the batch so the compiler can fit intermediate tensors in memory.
  batch_shape = boxes.shape[:-2]
  batch_size = np.prod(batch_shape, dtype=np.int32)
  boxes_size, struct_size = boxes.shape[-2:]
  boxes = tf.reshape(boxes, [batch_size, boxes_size, struct_size])
  scores = tf.reshape(scores, [batch_size, boxes_size])
  # Block size keeps each block's pairwise IOU matrix within the recommended
  # memory budget.
  block = max(1, _RECOMMENDED_NMS_MEMORY // (boxes_size * boxes_size))
  indices = []
  for boxes_i, scores_i in shard_tensors(0, block, boxes, scores):
    indices.append(
        _non_max_suppression_as_is(boxes_i, scores_i, output_size,
                                   iou_threshold))
  indices = tf.concat(indices, axis=0)
  return tf.reshape(indices, batch_shape + [output_size])

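# Usage sketch (shapes are assumptions): batched NMS over 2 images with 100
# candidate boxes each, keeping at most 10 boxes per image.
#   boxes = tf.random.uniform([2, 100, 4])
#   scores = tf.random.uniform([2, 100])
#   indices = non_max_suppression_padded(boxes, scores, 10)  # shape [2, 10]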

def _non_max_suppression_as_is(boxes: tf.Tensor,
                               scores: tf.Tensor,
                               output_size: int,
                               iou_threshold: float = 0.5) -> tf.Tensor:
  """Selects a subset of boxes with the highest score among IOU-similar boxes.

  Args:
    boxes: A 2-D+ float `Tensor` of shape `[...batch_dims, num_boxes, 4]`.
    scores: A 1-D+ float `Tensor` of shape `[...batch_dims, num_boxes]`
      representing a single score corresponding to each box (each row of
      boxes).
    output_size: A scalar integer `Tensor` representing the maximum number of
      boxes to be selected by non-max suppression.
    iou_threshold: A 0-D float tensor representing the threshold for deciding
      whether boxes overlap too much with respect to IOU.

  Returns:
    A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]`
    representing the selected indices from the boxes tensor, with `-1` values
    for the padding.
  """
  batch_shape = boxes.shape[:-2]
  batch_size = np.prod(batch_shape, dtype=np.int32)
  boxes_size = boxes.shape[-2]
  if boxes.shape[-1] != 4:
    raise ValueError(f'Boxes shape ({boxes.shape}) last dimension must be 4 '
                     'to represent [y1, x1, y2, x2] boxes coordinates')
  if scores.shape != boxes.shape[:-1]:
    raise ValueError(f'Boxes shape ({boxes.shape}) and scores shape '
                     f'({scores.shape}) do not match.')
  order = tf.range(boxes_size, dtype=tf.float32)
  relative_order = _tensor_sum_vectors(order, -order)
  relative_scores = _tensor_sum_vectors(scores, -scores)
  # A box is prunable if some IOU-similar box either scores higher, or ties
  # the score while appearing later in the input order.
  similar = _greater(_tensor_product_iou(boxes) - iou_threshold)
  worse = _greater(relative_scores)
  same_later = _and(_same(relative_scores), _greater(relative_order))
  similar_worse_or_same_later = _and(similar, _or(worse, same_later))
  prunable = _reduce_or(similar_worse_or_same_later, axis=-1)
  remaining = tf.constant(1.) - prunable
  # exp keeps all scores strictly positive, so masked-out (pruned) entries,
  # which become 0, always sort below remaining ones.
  scores = tf.reshape(tf.exp(scores), [1, 1, batch_size, boxes_size])
  remaining = tf.reshape(remaining, [1, 1, batch_size, boxes_size])
  # top_k runs on TPU cores; let it happen, the TPU tiles implementation is
  # slower.
  top_k = tf.math.top_k(scores * remaining, output_size)
  # Indices of pruned boxes (top_k value 0) become the -1 padding.
  indices = (
      tf.cast(top_k.indices, top_k.values.dtype) * _greater(top_k.values) -
      _same(top_k.values))
  return tf.reshape(indices, batch_shape + [output_size])

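# Toy illustration of the pruning rule (values are assumptions): box 1
# overlaps box 0 heavily (IOU > 0.5) and scores lower, so it is pruned;
# box 2 is disjoint and survives. Indices come back as floats, with -1.
# marking padding.
#   boxes = tf.constant([[[0., 0., 1., 1.], [0., 0., 1., 1.1],
#                         [2., 2., 3., 3.]]])
#   scores = tf.constant([[1., 0.9, 0.5]])
#   _non_max_suppression_as_is(boxes, scores, 2)  # -> [[0., 2.]]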

def concat_and_top_k(
    top_k: int, scores_pair: tuple[Optional[tf.Tensor], tf.Tensor],
    *other_pairs: tuple[Optional[tf.Tensor], tf.Tensor]
) -> tuple[tf.Tensor, ...]:
  """Combines shards of a top_k operation sharded along the filtered dimension.

  The general idea is that sometimes the pre-top_k dimension is very large,
  while top_k is moderately low (keep in mind a sample with a pre-top_k
  dimension of 15K and a top_k of 150). In that case it is possible to break
  the top_k input into groups significantly larger than top_k and
  significantly smaller than the pre-top_k dimension (keep in mind 1500). We
  do top_k over the first 1500 elements, then join the 150 survivors with the
  next 1500 elements (1650 in total) and repeat top_k. This function provides
  the repeatedly used step which concatenates a shard and applies top_k.

  For example with top_k = 2 and scores_pair = ([10, 6], [9, 8, 7]), output
  scores will be [10, 9].

  Other pairs are filtered using the indices generated from the scores. This
  is a pretty common case of filtering a structure by its score.

  For example with one extra pair of a box per score:
  top_k = 2
  scores_pair = ([10, 6],
                 [9, 8, 7])
  other_pairs = [([[0, 0, 10, 10], [0, 0, 6, 6]],
                  [[1, 1, 9, 9], [1, 1, 8, 8], [1, 1, 7, 7]])]
  Output is:
  ([10, 9], [[0, 0, 10, 10], [1, 1, 9, 9]])

  See also the 'test_top_k_sharded_fusion' unit test with an end to end
  example.

  Args:
    top_k: the top_k argument of the sharded tf.math.top_k.
    scores_pair: Tuple (<previous shards combination>, <additional shard>) of
      scores to be aggregated using top_k.
    *other_pairs: Tuples (<previous shards combination>, <additional shard>)
      of other values to be aggregated using the indices of top_k scores.

  Returns:
    Tuple of scores-based top_k aggregations with the additional shards
    folded in.
  """
  scores, scores_shard = scores_pair
  if other_pairs:
    others, others_shard = zip(*other_pairs)
  else:
    others = others_shard = []
  # Same as tf.rank, but avoids the tensor form for graph mode execution.
  top_k_dim: int = len(scores_shard.shape) - 1
  if scores is None:
    # The first shard becomes the aggregation.
    scores = scores_shard
    others = others_shard
  else:
    # Merge the shard into the aggregation.
    scores = tf.concat([scores, scores_shard], top_k_dim)
    others = [
        tf.concat([other, other_shard], top_k_dim)
        for other, other_shard in zip(others, others_shard)
    ]
  # When shards are uneven, some may be smaller than the requested top_k.
  if scores.shape[top_k_dim] > top_k:
    scores, indices = tf.nn.top_k(scores, top_k)
    others = [
        tf.gather(other, indices, axis=top_k_dim, batch_dims=top_k_dim)
        for other in others
    ]
  return scores, *others
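

# A hedged usage sketch (shard names are assumptions): fold score shards and
# their matching box shards into a running top-150 aggregation without ever
# materializing the full pre-top_k dimension.
#   scores, boxes = None, None
#   for score_shard, box_shard in zip(score_shards, box_shards):
#     scores, boxes = concat_and_top_k(150, (scores, score_shard),
#                                      (boxes, box_shard))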