Skip to content

Commit dbe3927

Browse files
xianzhidutensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 397809846
1 parent abc4fc0 commit dbe3927

File tree

9 files changed

+63
-38
lines changed

9 files changed

+63
-38
lines changed

official/vision/beta/configs/maskrcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class DetectionGenerator(hyperparams.Config):
131131
pre_nms_score_threshold: float = 0.05
132132
nms_iou_threshold: float = 0.5
133133
max_num_detections: int = 100
134-
use_batched_nms: bool = False
134+
nms_version: str = 'v2' # `v2`, `v1`, `batched`
135135
use_cpu_nms: bool = False
136136

137137

official/vision/beta/configs/retinanet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class DetectionGenerator(hyperparams.Config):
112112
pre_nms_score_threshold: float = 0.05
113113
nms_iou_threshold: float = 0.5
114114
max_num_detections: int = 100
115-
use_batched_nms: bool = False
115+
nms_version: str = 'v2' # `v2`, `v1`, `batched`.
116116
use_cpu_nms: bool = False
117117

118118

official/vision/beta/modeling/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def build_maskrcnn(
197197
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
198198
nms_iou_threshold=generator_config.nms_iou_threshold,
199199
max_num_detections=generator_config.max_num_detections,
200-
use_batched_nms=generator_config.use_batched_nms,
200+
nms_version=generator_config.nms_version,
201201
use_cpu_nms=generator_config.use_cpu_nms)
202202

203203
if model_config.include_mask:
@@ -300,7 +300,7 @@ def build_retinanet(
300300
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
301301
nms_iou_threshold=generator_config.nms_iou_threshold,
302302
max_num_detections=generator_config.max_num_detections,
303-
use_batched_nms=generator_config.use_batched_nms,
303+
nms_version=generator_config.nms_version,
304304
use_cpu_nms=generator_config.use_cpu_nms)
305305

306306
model = retinanet_model.RetinaNetModel(

official/vision/beta/modeling/layers/detection_generator.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def __init__(self,
404404
pre_nms_score_threshold: float = 0.05,
405405
nms_iou_threshold: float = 0.5,
406406
max_num_detections: int = 100,
407-
use_batched_nms: bool = False,
407+
nms_version: str = 'v2',
408408
use_cpu_nms: bool = False,
409409
**kwargs):
410410
"""Initializes a detection generator.
@@ -420,8 +420,7 @@ def __init__(self,
420420
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
421421
max_num_detections: An `int` of the final number of total detections to
422422
generate.
423-
use_batched_nms: A `bool` of whether or not use
424-
`tf.image.combined_non_max_suppression`.
423+
nms_version: A string of `batched`, `v1` or `v2` specifies NMS version.
425424
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
426425
**kwargs: Additional keyword arguments passed to Layer.
427426
"""
@@ -431,7 +430,7 @@ def __init__(self,
431430
'pre_nms_score_threshold': pre_nms_score_threshold,
432431
'nms_iou_threshold': nms_iou_threshold,
433432
'max_num_detections': max_num_detections,
434-
'use_batched_nms': use_batched_nms,
433+
'nms_version': nms_version,
435434
'use_cpu_nms': use_cpu_nms,
436435
}
437436
super(DetectionGenerator, self).__init__(**kwargs)
@@ -524,14 +523,14 @@ def __call__(self,
524523
nms_context = contextlib.nullcontext()
525524

526525
with nms_context:
527-
if self._config_dict['use_batched_nms']:
526+
if self._config_dict['nms_version'] == 'batched':
528527
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
529528
_generate_detections_batched(
530529
decoded_boxes, box_scores,
531530
self._config_dict['pre_nms_score_threshold'],
532531
self._config_dict['nms_iou_threshold'],
533532
self._config_dict['max_num_detections']))
534-
else:
533+
elif self._config_dict['nms_version'] == 'v1':
535534
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, _) = (
536535
_generate_detections_v1(
537536
decoded_boxes,
@@ -541,6 +540,19 @@ def __call__(self,
541540
._config_dict['pre_nms_score_threshold'],
542541
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
543542
max_num_detections=self._config_dict['max_num_detections']))
543+
elif self._config_dict['nms_version'] == 'v2':
544+
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
545+
_generate_detections_v2(
546+
decoded_boxes,
547+
box_scores,
548+
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
549+
pre_nms_score_threshold=self
550+
._config_dict['pre_nms_score_threshold'],
551+
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
552+
max_num_detections=self._config_dict['max_num_detections']))
553+
else:
554+
raise ValueError('NMS version {} not supported.'.format(
555+
self._config_dict['nms_version']))
544556

545557
# Adds 1 to offset the background class which has index 0.
546558
nmsed_classes += 1
@@ -570,7 +582,7 @@ def __init__(self,
570582
pre_nms_score_threshold: float = 0.05,
571583
nms_iou_threshold: float = 0.5,
572584
max_num_detections: int = 100,
573-
use_batched_nms: bool = False,
585+
nms_version: str = 'v1',
574586
use_cpu_nms: bool = False,
575587
**kwargs):
576588
"""Initializes a multi-level detection generator.
@@ -586,8 +598,7 @@ def __init__(self,
586598
nms_iou_threshold: A `float` in [0, 1], the NMS IoU threshold.
587599
max_num_detections: An `int` of the final number of total detections to
588600
generate.
589-
use_batched_nms: A `bool` of whether or not use
590-
`tf.image.combined_non_max_suppression`.
601+
nms_version: A string of `batched`, `v1` or `v2` specifies NMS version
591602
use_cpu_nms: A `bool` of whether or not enforce NMS to run on CPU.
592603
**kwargs: Additional keyword arguments passed to Layer.
593604
"""
@@ -597,7 +608,7 @@ def __init__(self,
597608
'pre_nms_score_threshold': pre_nms_score_threshold,
598609
'nms_iou_threshold': nms_iou_threshold,
599610
'max_num_detections': max_num_detections,
600-
'use_batched_nms': use_batched_nms,
611+
'nms_version': nms_version,
601612
'use_cpu_nms': use_cpu_nms,
602613
}
603614
super(MultilevelDetectionGenerator, self).__init__(**kwargs)
@@ -731,19 +742,19 @@ def __call__(self,
731742
nms_context = contextlib.nullcontext()
732743

733744
with nms_context:
734-
if self._config_dict['use_batched_nms']:
735-
if raw_attributes:
736-
raise ValueError(
737-
'Attribute learning is not supported for batched NMS.')
738-
745+
if raw_attributes and (self._config_dict['nms_version'] != 'v1'):
746+
raise ValueError(
747+
'Attribute learning is only supported for NMSv1 but NMS {} is used.'
748+
.format(self._config_dict['nms_version']))
749+
if self._config_dict['nms_version'] == 'batched':
739750
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
740751
_generate_detections_batched(
741752
boxes, scores, self._config_dict['pre_nms_score_threshold'],
742753
self._config_dict['nms_iou_threshold'],
743754
self._config_dict['max_num_detections']))
744755
# Set `nmsed_attributes` to None for batched NMS.
745756
nmsed_attributes = {}
746-
else:
757+
elif self._config_dict['nms_version'] == 'v1':
747758
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections,
748759
nmsed_attributes) = (
749760
_generate_detections_v1(
@@ -755,6 +766,21 @@ def __call__(self,
755766
._config_dict['pre_nms_score_threshold'],
756767
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
757768
max_num_detections=self._config_dict['max_num_detections']))
769+
elif self._config_dict['nms_version'] == 'v2':
770+
(nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = (
771+
_generate_detections_v2(
772+
boxes,
773+
scores,
774+
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
775+
pre_nms_score_threshold=self
776+
._config_dict['pre_nms_score_threshold'],
777+
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
778+
max_num_detections=self._config_dict['max_num_detections']))
779+
# Set `nmsed_attributes` to None for v2.
780+
nmsed_attributes = {}
781+
else:
782+
raise ValueError('NMS version {} not supported.'.format(
783+
self._config_dict['nms_version']))
758784

759785
# Adds 1 to offset the background class which has index 0.
760786
nmsed_classes += 1

official/vision/beta/modeling/layers/detection_generator_test.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class DetectionGeneratorTest(
4444
parameterized.TestCase, tf.test.TestCase):
4545

4646
@parameterized.product(
47-
use_batched_nms=[True, False], use_cpu_nms=[True, False])
48-
def testDetectionsOutputShape(self, use_batched_nms, use_cpu_nms):
47+
nms_version=['batched', 'v1', 'v2'], use_cpu_nms=[True, False])
48+
def testDetectionsOutputShape(self, nms_version, use_cpu_nms):
4949
max_num_detections = 100
5050
num_classes = 4
5151
pre_nms_top_k = 5000
@@ -57,7 +57,7 @@ def testDetectionsOutputShape(self, use_batched_nms, use_cpu_nms):
5757
'pre_nms_score_threshold': pre_nms_score_threshold,
5858
'nms_iou_threshold': 0.5,
5959
'max_num_detections': max_num_detections,
60-
'use_batched_nms': use_batched_nms,
60+
'nms_version': nms_version,
6161
'use_cpu_nms': use_cpu_nms,
6262
}
6363
generator = detection_generator.DetectionGenerator(**kwargs)
@@ -97,7 +97,7 @@ def test_serialize_deserialize(self):
9797
'pre_nms_score_threshold': 0.1,
9898
'nms_iou_threshold': 0.5,
9999
'max_num_detections': 10,
100-
'use_batched_nms': False,
100+
'nms_version': 'v2',
101101
'use_cpu_nms': False,
102102
}
103103
generator = detection_generator.DetectionGenerator(**kwargs)
@@ -116,15 +116,14 @@ class MultilevelDetectionGeneratorTest(
116116
parameterized.TestCase, tf.test.TestCase):
117117

118118
@parameterized.parameters(
119-
(True, False, True),
120-
(True, False, False),
121-
(False, False, True),
122-
(False, False, False),
123-
(False, True, True),
124-
(False, True, False),
119+
('batched', False, True),
120+
('batched', False, False),
121+
('v2', False, True),
122+
('v2', False, False),
123+
('v1', True, True),
124+
('v1', True, False),
125125
)
126-
def testDetectionsOutputShape(self, use_batched_nms, has_att_heads,
127-
use_cpu_nms):
126+
def testDetectionsOutputShape(self, nms_version, has_att_heads, use_cpu_nms):
128127
min_level = 4
129128
max_level = 6
130129
num_scales = 2
@@ -142,7 +141,7 @@ def testDetectionsOutputShape(self, use_batched_nms, has_att_heads,
142141
'pre_nms_score_threshold': pre_nms_score_threshold,
143142
'nms_iou_threshold': 0.5,
144143
'max_num_detections': max_num_detections,
145-
'use_batched_nms': use_batched_nms,
144+
'nms_version': nms_version,
146145
'use_cpu_nms': use_cpu_nms,
147146
}
148147

@@ -223,7 +222,7 @@ def test_serialize_deserialize(self):
223222
'pre_nms_score_threshold': 0.1,
224223
'nms_iou_threshold': 0.5,
225224
'max_num_detections': 10,
226-
'use_batched_nms': False,
225+
'nms_version': 'v2',
227226
'use_cpu_nms': False,
228227
}
229228
generator = detection_generator.MultilevelDetectionGenerator(**kwargs)

official/vision/beta/modeling/retinanet_model_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_forward(self, strategy, image_size, training, has_att_heads,
193193
attribute_heads=attribute_heads,
194194
num_anchors_per_location=num_anchors_per_location)
195195
generator = detection_generator.MultilevelDetectionGenerator(
196-
max_num_detections=10)
196+
max_num_detections=10, nms_version='v1')
197197
model = retinanet_model.RetinaNetModel(
198198
backbone=backbone,
199199
decoder=decoder,

official/vision/beta/projects/deepmac_maskrcnn/serving/detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _build_model(self):
2828

2929
if self._batch_size is None:
3030
ValueError("batch_size can't be None for detection models")
31-
if not self.params.task.model.detection_generator.use_batched_nms:
31+
if self.params.task.model.detection_generator.nms_version != 'batched':
3232
ValueError('Only batched_nms is supported.')
3333
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
3434
self._input_image_size + [3])

official/vision/beta/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
120120
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
121121
nms_iou_threshold=generator_config.nms_iou_threshold,
122122
max_num_detections=generator_config.max_num_detections,
123-
use_batched_nms=generator_config.use_batched_nms)
123+
nms_version=generator_config.nms_version)
124124

125125
if model_config.include_mask:
126126
mask_head = deep_instance_heads.DeepMaskHead(

official/vision/beta/serving/detection_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
3333
def _get_detection_module(self, experiment_name):
3434
params = exp_factory.get_exp_config(experiment_name)
3535
params.task.model.backbone.resnet.model_id = 18
36-
params.task.model.detection_generator.use_batched_nms = True
36+
params.task.model.detection_generator.nms_version = 'batched'
3737
detection_module = detection.DetectionModule(
3838
params, batch_size=1, input_image_size=[640, 640])
3939
return detection_module

0 commit comments

Comments
 (0)