@@ -404,7 +404,7 @@ def __init__(self,
404
404
pre_nms_score_threshold : float = 0.05 ,
405
405
nms_iou_threshold : float = 0.5 ,
406
406
max_num_detections : int = 100 ,
407
- use_batched_nms : bool = False ,
407
+ nms_version : str = 'v2' ,
408
408
use_cpu_nms : bool = False ,
409
409
** kwargs ):
410
410
"""Initializes a detection generator.
@@ -420,8 +420,7 @@ def __init__(self,
420
420
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
421
421
max_num_detections: An `int` of the final number of total detections to
422
422
generate.
423
- use_batched_nms: A `bool` of whether or not use
424
- `tf.image.combined_non_max_suppression`.
423
+ nms_version: A string of `batched`, `v1` or `v2` specifies NMS version.
425
424
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
426
425
**kwargs: Additional keyword arguments passed to Layer.
427
426
"""
@@ -431,7 +430,7 @@ def __init__(self,
431
430
'pre_nms_score_threshold' : pre_nms_score_threshold ,
432
431
'nms_iou_threshold' : nms_iou_threshold ,
433
432
'max_num_detections' : max_num_detections ,
434
- 'use_batched_nms ' : use_batched_nms ,
433
+ 'nms_version ' : nms_version ,
435
434
'use_cpu_nms' : use_cpu_nms ,
436
435
}
437
436
super (DetectionGenerator , self ).__init__ (** kwargs )
@@ -524,14 +523,14 @@ def __call__(self,
524
523
nms_context = contextlib .nullcontext ()
525
524
526
525
with nms_context :
527
- if self ._config_dict ['use_batched_nms' ] :
526
+ if self ._config_dict ['nms_version' ] == 'batched' :
528
527
(nmsed_boxes , nmsed_scores , nmsed_classes , valid_detections ) = (
529
528
_generate_detections_batched (
530
529
decoded_boxes , box_scores ,
531
530
self ._config_dict ['pre_nms_score_threshold' ],
532
531
self ._config_dict ['nms_iou_threshold' ],
533
532
self ._config_dict ['max_num_detections' ]))
534
- else :
533
+ elif self . _config_dict [ 'nms_version' ] == 'v1' :
535
534
(nmsed_boxes , nmsed_scores , nmsed_classes , valid_detections , _ ) = (
536
535
_generate_detections_v1 (
537
536
decoded_boxes ,
@@ -541,6 +540,19 @@ def __call__(self,
541
540
._config_dict ['pre_nms_score_threshold' ],
542
541
nms_iou_threshold = self ._config_dict ['nms_iou_threshold' ],
543
542
max_num_detections = self ._config_dict ['max_num_detections' ]))
543
+ elif self ._config_dict ['nms_version' ] == 'v2' :
544
+ (nmsed_boxes , nmsed_scores , nmsed_classes , valid_detections ) = (
545
+ _generate_detections_v2 (
546
+ decoded_boxes ,
547
+ box_scores ,
548
+ pre_nms_top_k = self ._config_dict ['pre_nms_top_k' ],
549
+ pre_nms_score_threshold = self
550
+ ._config_dict ['pre_nms_score_threshold' ],
551
+ nms_iou_threshold = self ._config_dict ['nms_iou_threshold' ],
552
+ max_num_detections = self ._config_dict ['max_num_detections' ]))
553
+ else :
554
+ raise ValueError ('NMS version {} not supported.' .format (
555
+ self ._config_dict ['nms_version' ]))
544
556
545
557
# Adds 1 to offset the background class which has index 0.
546
558
nmsed_classes += 1
@@ -570,7 +582,7 @@ def __init__(self,
570
582
pre_nms_score_threshold : float = 0.05 ,
571
583
nms_iou_threshold : float = 0.5 ,
572
584
max_num_detections : int = 100 ,
573
- use_batched_nms : bool = False ,
585
+ nms_version : str = 'v1' ,
574
586
use_cpu_nms : bool = False ,
575
587
** kwargs ):
576
588
"""Initializes a multi-level detection generator.
@@ -586,8 +598,7 @@ def __init__(self,
586
598
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
587
599
max_num_detections: An `int` of the final number of total detections to
588
600
generate.
589
- use_batched_nms: A `bool` of whether or not use
590
- `tf.image.combined_non_max_suppression`.
601
+ nms_version: A string of `batched`, `v1` or `v2` specifies NMS version
591
602
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
592
603
**kwargs: Additional keyword arguments passed to Layer.
593
604
"""
@@ -597,7 +608,7 @@ def __init__(self,
597
608
'pre_nms_score_threshold' : pre_nms_score_threshold ,
598
609
'nms_iou_threshold' : nms_iou_threshold ,
599
610
'max_num_detections' : max_num_detections ,
600
- 'use_batched_nms ' : use_batched_nms ,
611
+ 'nms_version ' : nms_version ,
601
612
'use_cpu_nms' : use_cpu_nms ,
602
613
}
603
614
super (MultilevelDetectionGenerator , self ).__init__ (** kwargs )
@@ -731,19 +742,19 @@ def __call__(self,
731
742
nms_context = contextlib .nullcontext ()
732
743
733
744
with nms_context :
734
- if self ._config_dict ['use_batched_nms' ] :
735
- if raw_attributes :
736
- raise ValueError (
737
- 'Attribute learning is not supported for batched NMS.' )
738
-
745
+ if raw_attributes and ( self ._config_dict ['nms_version' ] != 'v1' ) :
746
+ raise ValueError (
747
+ 'Attribute learning is only supported for NMSv1 but NMS {} is used.'
748
+ . format ( self . _config_dict [ 'nms_version' ]) )
749
+ if self . _config_dict [ 'nms_version' ] == 'batched' :
739
750
(nmsed_boxes , nmsed_scores , nmsed_classes , valid_detections ) = (
740
751
_generate_detections_batched (
741
752
boxes , scores , self ._config_dict ['pre_nms_score_threshold' ],
742
753
self ._config_dict ['nms_iou_threshold' ],
743
754
self ._config_dict ['max_num_detections' ]))
744
755
# Set `nmsed_attributes` to None for batched NMS.
745
756
nmsed_attributes = {}
746
- else :
757
+ elif self . _config_dict [ 'nms_version' ] == 'v1' :
747
758
(nmsed_boxes , nmsed_scores , nmsed_classes , valid_detections ,
748
759
nmsed_attributes ) = (
749
760
_generate_detections_v1 (
@@ -755,6 +766,21 @@ def __call__(self,
755
766
._config_dict ['pre_nms_score_threshold' ],
756
767
nms_iou_threshold = self ._config_dict ['nms_iou_threshold' ],
757
768
max_num_detections = self ._config_dict ['max_num_detections' ]))
769
+ elif self ._config_dict ['nms_version' ] == 'v2' :
770
+ (nmsed_boxes , nmsed_scores , nmsed_classes , valid_detections ) = (
771
+ _generate_detections_v2 (
772
+ boxes ,
773
+ scores ,
774
+ pre_nms_top_k = self ._config_dict ['pre_nms_top_k' ],
775
+ pre_nms_score_threshold = self
776
+ ._config_dict ['pre_nms_score_threshold' ],
777
+ nms_iou_threshold = self ._config_dict ['nms_iou_threshold' ],
778
+ max_num_detections = self ._config_dict ['max_num_detections' ]))
779
+ # Set `nmsed_attributes` to None for v2.
780
+ nmsed_attributes = {}
781
+ else :
782
+ raise ValueError ('NMS version {} not supported.' .format (
783
+ self ._config_dict ['nms_version' ]))
758
784
759
785
# Adds 1 to offset the background class which has index 0.
760
786
nmsed_classes += 1
0 commit comments