13
13
# limitations under the License.
14
14
15
15
"""EdgeTPU oriented layers and tools."""
16
-
17
16
from collections .abc import Iterable , Sequence
18
- from typing import Optional
17
+ from typing import List , Optional , Union
19
18
20
19
import numpy as np
21
20
import tensorflow as tf
@@ -51,7 +50,8 @@ def _tensor_product_iou(boxes):
51
50
# - last dimension is not 1. (Structure alignment)
52
51
tpu_friendly_shape = [1 , - 1 , 1 , boxes_size ]
53
52
bottom , left , top , right = (
54
- tf .reshape (side , tpu_friendly_shape ) for side in tf .split (boxes , 4 , - 1 ))
53
+ tf .reshape (side , tpu_friendly_shape )
54
+ for side in tf .split (boxes , 4 , - 1 ))
55
55
height , width = top - bottom , right - left
56
56
area = height * width
57
57
area_sum = _tensor_sum_vectors (area , area )
@@ -116,6 +116,8 @@ def shard_tensors(axis: int, block_size: int,
116
116
Raises:
117
117
ValueError: if input tensors has different size of sharded dimension.
118
118
"""
119
+ if not all (tensor .shape .is_fully_defined () for tensor in tensors ):
120
+ return [tensors ]
119
121
for validate_axis in range (axis + 1 ):
120
122
consistent_length : int = tensors [0 ].shape [validate_axis ]
121
123
for tensor in tensors :
@@ -195,6 +197,8 @@ def non_max_suppression_padded(boxes: tf.Tensor,
195
197
A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]` representing
196
198
the selected indices from the boxes tensor and `-1` values for the padding.
197
199
"""
200
+ if not boxes .shape .is_fully_defined ():
201
+ return _non_max_suppression_as_is (boxes , scores , output_size , iou_threshold )
198
202
# Does partitioning job to help compiler converge with memory.
199
203
batch_shape = boxes .shape [:- 2 ]
200
204
batch_size = np .prod (batch_shape , dtype = np .int32 )
@@ -211,6 +215,52 @@ def non_max_suppression_padded(boxes: tf.Tensor,
211
215
return tf .reshape (indices , batch_shape + [output_size ])
212
216
213
217
218
+ def _refine_nms_graph_to_original_algorithm (better : tf .Tensor ) -> tf .Tensor :
219
+ """Refines the relationship graph, bringing it closer to the iterative NMS.
220
+
221
+ See `test_refinement_sample` unit tests for example, also comments in body of
222
+ the algorithm, for the intuition.
223
+
224
+ Args:
225
+ better: is a tensor with zeros and ones so that [batch dims ..., box_1,
226
+ box_2] represents the [adjacency
227
+ matrix](https://en.wikipedia.org/wiki/Adjacency_matrix) for the
228
+ [relation](https://en.wikipedia.org/wiki/Relation_(mathematics)) `better`
229
+ between boxes box_1 and box_2.
230
+
231
+ Returns:
232
+ Modification of tensor encoding adjacency matrix of `better` relation.
233
+ """
234
+ # good_box: is a tensor with zeros and ones so that
235
+ # [batch dims ..., box_i] represents belonging of a box_i to the `good`
236
+ # subset. `good` subset is defined as exactly those boxes that do not have any
237
+ # `better` boxes.
238
+ # INTUITION: In terms of oriented graph , this is subset of nodes nobody
239
+ # points to as "I'm better than you". These nodes will never be suppressed in
240
+ # the original NMS algorithm.
241
+ good_box = tf .constant (1. ) - _reduce_or (better , axis = - 1 )
242
+ # good_better: is a tensor with zeros and ones so that
243
+ # [batch dims ..., box_1, box_2] represents the adjacency matrix for the
244
+ # `good_better` relation on all boxes set. `good_better` relation is defined
245
+ # as relation between good box and boxes it is better than.
246
+ # INTUITION: In terms of oriented graph, this is subset of edges, which
247
+ # doesn't have any other inbound edges. These edges will represent
248
+ # suppression actions in the original NMS algorithm.
249
+ good_better = _and (tf .expand_dims (good_box , axis = - 2 ), better )
250
+ # not_bad_box: is a tensor with zeros and ones so that
251
+ # [batch dims ..., box_i] represents belonging of a box_i to the `not_bad`
252
+ # subset. `not_bad` subset is defined as boxes all that and only those that
253
+ # does not have any `good_better` boxes.
254
+ # INTUITION: These nodes are nodes which are not suppressed by `good` boxes
255
+ # in the original NMS algorithm.
256
+ not_bad_box = tf .constant (1. ) - _reduce_or (good_better , axis = - 1 )
257
+ # return: is a tensor with zeros and ones so that
258
+ # [batch dims ..., box_1, box_2] represents the adjacency matrix for the
259
+ # `better` relation on all boxes set which is closer to represent suppression
260
+ # procedure in original NMS algorithm.
261
+ return _and (tf .expand_dims (not_bad_box , axis = - 2 ), better )
262
+
263
+
214
264
def _non_max_suppression_as_is (boxes : tf .Tensor ,
215
265
scores : tf .Tensor ,
216
266
output_size : int ,
@@ -230,32 +280,34 @@ def _non_max_suppression_as_is(boxes: tf.Tensor,
230
280
A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]` representing
231
281
the selected indices from the boxes tensor and `-1` values for the padding.
232
282
"""
233
- batch_shape = boxes .shape [:- 2 ]
234
- batch_size = np .prod (batch_shape , dtype = np .int32 )
235
283
boxes_size = boxes .shape [- 2 ]
236
284
if boxes .shape [- 1 ] != 4 :
237
285
raise ValueError (f'Boxes shape ({ boxes .shape } ) last dimension must be 4 '
238
286
'to represent [y1, x1, y2, x2] boxes coordinates' )
239
287
if scores .shape != boxes .shape [:- 1 ]:
240
288
raise ValueError (f'Boxes shape ({ boxes .shape } ) and scores shape '
241
289
f'({ scores .shape } ) do not match.' )
242
- order = tf .range ( boxes_size , dtype = tf . float32 )
290
+ order = tf .constant ( np . arange ( boxes_size ) , dtype = scores . dtype )
243
291
relative_order = _tensor_sum_vectors (order , - order )
244
292
relative_scores = _tensor_sum_vectors (scores , - scores )
245
- similar = _greater (_tensor_product_iou (boxes ) - iou_threshold )
293
+ similar = tf .cast (
294
+ _greater (
295
+ _tensor_product_iou (boxes ) -
296
+ tf .constant (iou_threshold , dtype = boxes .dtype )), scores .dtype )
246
297
worse = _greater (relative_scores )
247
298
same_later = _and (_same (relative_scores ), _greater (relative_order ))
248
299
similar_worse_or_same_later = _and (similar , _or (worse , same_later ))
249
300
prunable = _reduce_or (similar_worse_or_same_later , axis = - 1 )
250
- remaining = tf .constant (1. ) - prunable
251
- scores = tf .reshape (tf .exp (scores ), [1 , 1 , batch_size , boxes_size ])
252
- remaining = tf .reshape (remaining , [1 , 1 , batch_size , boxes_size ])
301
+ remaining = tf .constant (1 , dtype = prunable .dtype ) - prunable
302
+ if scores .shape [0 ] is None :
303
+ # Prefer the most of tesnor shape defined, so that error messages are clear.
304
+ remaining = tf .reshape (remaining , [tf .shape (scores )[0 ], * scores .shape [1 :]])
305
+ else :
306
+ remaining = tf .reshape (remaining , scores .shape )
253
307
# top_k runs on TPU cores, let it happen, TPU tiles implementation is slower.
254
308
top_k = tf .math .top_k (scores * remaining , output_size )
255
- indices = (
256
- tf .cast (top_k .indices , top_k .values .dtype ) * _greater (top_k .values ) -
257
- _same (top_k .values ))
258
- return tf .reshape (indices , batch_shape + [output_size ])
309
+ return (tf .cast (top_k .indices , top_k .values .dtype ) * _greater (top_k .values ) -
310
+ _same (top_k .values ))
259
311
260
312
261
313
def concat_and_top_k (
0 commit comments