@@ -328,20 +328,130 @@ def _select_top_k_scores(scores_in: tf.Tensor, pre_nms_num_detections: int):
328
328
)
329
329
330
330
331
- def _generate_detections_v2 (
331
def _generate_detections_v2_class_agnostic(
    boxes: tf.Tensor,
    scores: tf.Tensor,
    pre_nms_top_k: int = 5000,
    pre_nms_score_threshold: float = 0.05,
    nms_iou_threshold: float = 0.5,
    max_num_detections: int = 100,
):
  """Generates the final detections by applying class-agnostic NMS.

  Each box is first reduced to its single highest-scoring class. The
  `pre_nms_top_k` best boxes (across all classes) are then passed through one
  batched `tf.image.non_max_suppression_padded` call, so boxes of different
  classes can suppress each other.

  Args:
    boxes: A `tf.Tensor` with shape `[batch_size, N, num_classes, 4]` or
      `[batch_size, N, 1, 4]`, holding box predictions from all feature
      levels. The N is the number of total anchors on all levels.
    scores: A `tf.Tensor` with shape `[batch_size, N, num_classes]`, which
      stacks class probability on all feature levels. The N is the number of
      total anchors on all levels. The num_classes is the number of classes
      predicted by the model. Note that the class_outputs here is the raw
      score.
    pre_nms_top_k: An `int` number of top-scoring candidate boxes kept before
      NMS. Unlike the class-aware variant this is a total over all classes,
      because the class dimension has already been collapsed.
    pre_nms_score_threshold: A `float` representing the threshold for deciding
      when to remove boxes based on score.
    nms_iou_threshold: A `float` representing the threshold for deciding
      whether boxes overlap too much with respect to IOU.
    max_num_detections: A `scalar` representing maximum number of boxes
      retained over all classes.

  Returns:
    nms_boxes: A `float` tf.Tensor of shape [batch_size, max_num_detections, 4]
      representing top detected boxes in [y1, x1, y2, x2]. Padded entries are
      zero.
    nms_scores: A `float` tf.Tensor of shape [batch_size, max_num_detections]
      representing confidence scores for detected boxes. Padded entries are
      zero.
    nms_classes: An `int` tf.Tensor of shape [batch_size, max_num_detections]
      representing classes for detected boxes. Padded entries are zero.
    valid_detections: An `int` tf.Tensor of shape [batch_size] only the top
      `valid_detections` boxes are valid detections.
  """
  with tf.name_scope('generate_detections_class_agnostic'):
    # These static dimensions feed Python-level comparisons below
    # (`> 1` and `min(...)`), so they must be known when this code is traced.
    _, _, num_classes_for_box, _ = boxes.get_shape().as_list()
    _, total_anchors, _ = scores.get_shape().as_list()

    # Keeps only the class with highest score for each predicted box.
    scores_condensed, classes_ids = tf.nn.top_k(scores, k=1, sorted=True)
    scores_condensed = tf.squeeze(scores_condensed, axis=[2])
    if num_classes_for_box > 1:
      # Per-class boxes: pick the box that belongs to the winning class.
      boxes = tf.gather(boxes, classes_ids, axis=2, batch_dims=2)
    boxes_condensed = tf.squeeze(boxes, axis=[2])
    classes_condensed = tf.squeeze(classes_ids, axis=[2])

    # Selects top pre_nms_num scores and indices before NMS.
    num_anchors_filtered = min(total_anchors, pre_nms_top_k)
    scores_filtered, indices_filtered = tf.nn.top_k(
        scores_condensed, k=num_anchors_filtered, sorted=True
    )
    classes_filtered = tf.gather(
        classes_condensed, indices_filtered, axis=1, batch_dims=1
    )
    boxes_filtered = tf.gather(
        boxes_condensed, indices_filtered, axis=1, batch_dims=1
    )

    # `tf.ensure_shape` returns a new tensor carrying the shape check; the
    # result has to be used, otherwise the check is not attached to anything
    # consumed downstream.
    boxes_filtered = tf.ensure_shape(
        boxes_filtered, [None, num_anchors_filtered, 4]
    )
    classes_filtered = tf.ensure_shape(
        classes_filtered, [None, num_anchors_filtered]
    )
    scores_filtered = tf.ensure_shape(
        scores_filtered, [None, num_anchors_filtered]
    )
    boxes_filtered = tf.cast(boxes_filtered, tf.float32)
    scores_filtered = tf.cast(scores_filtered, tf.float32)

    # Apply class-agnostic NMS on boxes. The scores come out of `top_k`
    # already sorted, hence `sorted_input=True`.
    nmsed_indices_padded, valid_detections = (
        tf.image.non_max_suppression_padded(
            boxes=boxes_filtered,
            scores=scores_filtered,
            max_output_size=max_num_detections,
            iou_threshold=nms_iou_threshold,
            pad_to_max_output_size=True,
            score_threshold=pre_nms_score_threshold,
            sorted_input=True,
            name='nms_detections',
        )
    )
    nmsed_boxes = tf.gather(
        boxes_filtered, nmsed_indices_padded, batch_dims=1, axis=1
    )
    nmsed_scores = tf.gather(
        scores_filtered, nmsed_indices_padded, batch_dims=1, axis=1
    )
    nmsed_classes = tf.gather(
        classes_filtered, nmsed_indices_padded, batch_dims=1, axis=1
    )

    # Sets the padded boxes, scores, and classes to 0.
    # padding_mask[b, i] is True for the first valid_detections[b] slots.
    padding_mask = tf.reshape(
        tf.range(max_num_detections), [1, -1]
    ) < tf.reshape(valid_detections, [-1, 1])
    nmsed_boxes = nmsed_boxes * tf.cast(
        tf.expand_dims(padding_mask, axis=2), nmsed_boxes.dtype
    )
    nmsed_scores = nmsed_scores * tf.cast(padding_mask, nmsed_scores.dtype)
    nmsed_classes = nmsed_classes * tf.cast(padding_mask, nmsed_classes.dtype)

  return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
444
+
445
+
446
+ def _generate_detections_v2_class_aware (
447
+ boxes : tf .Tensor ,
448
+ scores : tf .Tensor ,
449
+ pre_nms_top_k : int = 5000 ,
450
+ pre_nms_score_threshold : float = 0.05 ,
451
+ nms_iou_threshold : float = 0.5 ,
452
+ max_num_detections : int = 100 ,
453
+ ):
454
+ """Generates the final detections by using class-aware NMS.
345
455
346
456
Args:
347
457
boxes: A `tf.Tensor` with shape `[batch_size, N, num_classes, 4]` or
@@ -419,6 +529,72 @@ def _generate_detections_v2(
419
529
return nmsed_boxes , nmsed_scores , nmsed_classes , valid_detections
420
530
421
531
532
def _generate_detections_v2(
    boxes: tf.Tensor,
    scores: tf.Tensor,
    pre_nms_top_k: int = 5000,
    pre_nms_score_threshold: float = 0.05,
    nms_iou_threshold: float = 0.5,
    max_num_detections: int = 100,
    use_class_agnostic_nms: Optional[bool] = None,
):
  """Generates the final detections given the model outputs.

  Thin dispatcher over the two v2 NMS implementations: when
  `use_class_agnostic_nms` is truthy the call is forwarded to
  `_generate_detections_v2_class_agnostic`, otherwise (including the default
  `None`) to `_generate_detections_v2_class_aware`. All other arguments are
  passed through unchanged.

  Args:
    boxes: A `tf.Tensor` with shape `[batch_size, N, num_classes, 4]` or
      `[batch_size, N, 1, 4]`, holding box predictions on all feature levels.
      The N is the number of total anchors on all levels.
    scores: A `tf.Tensor` with shape `[batch_size, N, num_classes]`, which
      stacks class probability on all feature levels. The N is the number of
      total anchors on all levels. The num_classes is the number of classes
      predicted by the model. Note that the class_outputs here is the raw
      score.
    pre_nms_top_k: An `int` number of top candidate detections kept before
      NMS.
    pre_nms_score_threshold: A `float` representing the threshold for deciding
      when to remove boxes based on score.
    nms_iou_threshold: A `float` representing the threshold for deciding
      whether boxes overlap too much with respect to IOU.
    max_num_detections: A `scalar` representing maximum number of boxes
      retained over all classes.
    use_class_agnostic_nms: A `bool` of whether non max suppression is
      operated on all the boxes using max scores across all classes.

  Returns:
    nms_boxes: A `float` tf.Tensor of shape [batch_size, max_num_detections, 4]
      representing top detected boxes in [y1, x1, y2, x2].
    nms_scores: A `float` tf.Tensor of shape [batch_size, max_num_detections]
      representing sorted confidence scores for detected boxes.
    nms_classes: An `int` tf.Tensor of shape [batch_size, max_num_detections]
      representing classes for detected boxes.
    valid_detections: An `int` tf.Tensor of shape [batch_size] only the top
      `valid_detections` boxes are valid detections.
  """
  # Both implementations share the same keyword interface, so choose the
  # callee once and make a single forwarding call.
  nms_impl = (
      _generate_detections_v2_class_agnostic
      if use_class_agnostic_nms
      else _generate_detections_v2_class_aware
  )
  return nms_impl(
      boxes=boxes,
      scores=scores,
      pre_nms_top_k=pre_nms_top_k,
      pre_nms_score_threshold=pre_nms_score_threshold,
      nms_iou_threshold=nms_iou_threshold,
      max_num_detections=max_num_detections,
  )
596
+
597
+
422
598
def _generate_detections_v3 (
423
599
boxes : tf .Tensor ,
424
600
scores : tf .Tensor ,
@@ -957,6 +1133,7 @@ def __init__(
957
1133
pre_nms_top_k_sharding_block : Optional [int ] = None ,
958
1134
nms_v3_refinements : Optional [int ] = None ,
959
1135
return_decoded : Optional [bool ] = None ,
1136
+ use_class_agnostic_nms : Optional [bool ] = None ,
960
1137
** kwargs ,
961
1138
):
962
1139
"""Initializes a multi-level detection generator.
@@ -989,8 +1166,21 @@ def __init__(
989
1166
if == 2, AP is reduced <0.1%, AR is reduced <1% on COCO
990
1167
return_decoded: A `bool` of whether to return decoded boxes before NMS
991
1168
regardless of whether `apply_nms` is True or not.
1169
+ use_class_agnostic_nms: A `bool` of whether non max suppression is
1170
+ operated on all the boxes using max scores across all classes.
992
1171
**kwargs: Additional keyword arguments passed to Layer.
1172
+
1173
+ Raises:
1174
+ ValueError: If `use_class_agnostic_nms` is enabled but `nms_version` is
1175
+ not specified as `v2`.
993
1176
"""
1177
+ if use_class_agnostic_nms and nms_version != 'v2' :
1178
+ raise ValueError (
1179
+ 'If not using TFLite custom NMS, `use_class_agnostic_nms` can only be'
1180
+ ' enabled for NMS v2 for now, but NMS {} is used! If you are using'
1181
+ ' TFLite NMS, please configure TFLite custom NMS for class-agnostic'
1182
+ ' NMS.' .format (nms_version )
1183
+ )
994
1184
self ._config_dict = {
995
1185
'apply_nms' : apply_nms ,
996
1186
'pre_nms_top_k' : pre_nms_top_k ,
@@ -1001,6 +1191,7 @@ def __init__(
1001
1191
'use_cpu_nms' : use_cpu_nms ,
1002
1192
'soft_nms_sigma' : soft_nms_sigma ,
1003
1193
'return_decoded' : return_decoded ,
1194
+ 'use_class_agnostic_nms' : use_class_agnostic_nms ,
1004
1195
}
1005
1196
# Don't store if were not defined
1006
1197
if pre_nms_top_k_sharding_block is not None :
@@ -1347,6 +1538,9 @@ def __call__(
1347
1538
],
1348
1539
nms_iou_threshold = self ._config_dict ['nms_iou_threshold' ],
1349
1540
max_num_detections = self ._config_dict ['max_num_detections' ],
1541
+ use_class_agnostic_nms = self ._config_dict [
1542
+ 'use_class_agnostic_nms'
1543
+ ],
1350
1544
)
1351
1545
)
1352
1546
# Set `nmsed_attributes` to None for v2.
0 commit comments