Skip to content

Commit 056cd61

Browse files
ziyeqinghan authored and fyangf committed
Normalize anchors during TFLite post-processing in object detection.
PiperOrigin-RevId: 486704284
1 parent 7de35cf commit 056cd61

File tree

3 files changed: 16 additions and 4 deletions

3 files changed

+16
-4
lines changed

official/vision/modeling/factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,11 @@ def build_retinanet(
306306
decoder_features = decoder(backbone_features)
307307
_ = head(decoder_features)
308308

309+
# Add `input_image_size` into `tflite_post_processing_config`.
310+
tflite_post_processing_config = generator_config.tflite_post_processing.as_dict(
311+
)
312+
tflite_post_processing_config['input_image_size'] = (input_specs.shape[1],
313+
input_specs.shape[2])
309314
detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
310315
apply_nms=generator_config.apply_nms,
311316
pre_nms_top_k=generator_config.pre_nms_top_k,
@@ -315,8 +320,7 @@ def build_retinanet(
315320
nms_version=generator_config.nms_version,
316321
use_cpu_nms=generator_config.use_cpu_nms,
317322
soft_nms_sigma=generator_config.soft_nms_sigma,
318-
tflite_post_processing_config=generator_config.tflite_post_processing
319-
.as_dict())
323+
tflite_post_processing_config=tflite_post_processing_config)
320324

321325
model = retinanet_model.RetinaNetModel(
322326
backbone,

official/vision/modeling/layers/detection_generator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,12 @@ def _generate_detections_tflite(raw_boxes: Mapping[str, tf.Tensor],
533533
wa = anchors[..., 3] - anchors[..., 1]
534534
anchors = tf.stack([ycenter_a, xcenter_a, ha, wa], axis=-1)
535535

536+
# TFLite's object detection APIs require normalized anchors.
537+
height, width = config['input_image_size']
538+
normalize_factor = tf.constant([height, width, height, width],
539+
dtype=tf.float32)
540+
anchors = anchors / normalize_factor
541+
536542
# There is no TF equivalent for TFLite's custom post-processing op.
537543
# So we add an 'empty' composite function here, that is legalized to the
538544
# custom op with MLIR.

official/vision/modeling/layers/detection_generator_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ def testDetectionsOutputShape(self, nms_version, has_att_heads, use_cpu_nms,
148148
'max_classes_per_detection': 1,
149149
'use_regular_nms': use_regular_nms,
150150
'nms_score_threshold': 0.01,
151-
'nms_iou_threshold': 0.5
151+
'nms_iou_threshold': 0.5,
152+
'input_image_size': [224, 224],
152153
}
153154
kwargs = {
154155
'apply_nms': True,
@@ -253,7 +254,8 @@ def test_serialize_deserialize(self):
253254
'max_classes_per_detection': 1,
254255
'use_regular_nms': True,
255256
'nms_score_threshold': 0.01,
256-
'nms_iou_threshold': 0.5
257+
'nms_iou_threshold': 0.5,
258+
'input_image_size': [224, 224],
257259
}
258260
kwargs = {
259261
'apply_nms': True,

0 commit comments

Comments
 (0)