Skip to content

Commit 7bb6a23

Browse files
committed
Internal change
PiperOrigin-RevId: 506457533
1 parent 0d85d22 commit 7bb6a23

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

official/vision/configs/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,6 @@ class TFLitePostProcessingConfig(hyperparams.Config):
145145
use_regular_nms: bool = False
146146
nms_score_threshold: float = 0.1
147147
nms_iou_threshold: float = 0.5
148+
# Whether to normalize coordinates of anchors to [0, 1]. If setting to True,
149+
# coordinates of output boxes is also normalized but latency increases.
150+
normalize_anchor_coordinates: Optional[bool] = False

official/vision/modeling/layers/detection_generator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,13 @@ def _generate_detections_tflite(raw_boxes: Mapping[str, tf.Tensor],
533533
wa = anchors[..., 3] - anchors[..., 1]
534534
anchors = tf.stack([ycenter_a, xcenter_a, ha, wa], axis=-1)
535535

536-
# TFLite's object detection APIs require normalized anchors.
537-
height, width = config['input_image_size']
538-
normalize_factor = tf.constant([height, width, height, width],
539-
dtype=tf.float32)
540-
anchors = anchors / normalize_factor
536+
if config.get('normalize_anchor_coordinates', False):
537+
# TFLite's object detection APIs require normalized anchors.
538+
height, width = config['input_image_size']
539+
normalize_factor = tf.constant(
540+
[height, width, height, width], dtype=tf.float32
541+
)
542+
anchors = anchors / normalize_factor
541543

542544
# There is no TF equivalent for TFLite's custom post-processing op.
543545
# So we add an 'empty' composite function here, that is legalized to the

official/vision/modeling/layers/detection_generator_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def testDetectionsOutputShape(self, nms_version, has_att_heads, use_cpu_nms,
150150
'nms_score_threshold': 0.01,
151151
'nms_iou_threshold': 0.5,
152152
'input_image_size': [224, 224],
153+
'normalize_anchor_coordinates': True,
153154
}
154155
kwargs = {
155156
'apply_nms': True,

0 commit comments

Comments
 (0)