Skip to content

Commit 50ebc68

Browse files
Add model export config to RetinaNet, which optionally cast model outputs to floats, and normalize output box coordinates.
PiperOrigin-RevId: 381334851
1 parent c736968 commit 50ebc68

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

official/vision/beta/configs/retinanet.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,13 @@ class RetinaNet(hyperparams.Config):
130130
norm_activation: common.NormActivation = common.NormActivation()
131131

132132

133+
@dataclasses.dataclass
134+
class ExportConfig(hyperparams.Config):
135+
output_normalized_coordinates: bool = False
136+
cast_num_detections_to_float: bool = False
137+
cast_detection_classes_to_float: bool = False
138+
139+
133140
@dataclasses.dataclass
134141
class RetinaNetTask(cfg.TaskConfig):
135142
model: RetinaNet = RetinaNet()
@@ -140,6 +147,7 @@ class RetinaNetTask(cfg.TaskConfig):
140147
init_checkpoint_modules: str = 'all' # all or backbone
141148
annotation_file: Optional[str] = None
142149
per_category_metrics: bool = False
150+
export_config: ExportConfig = ExportConfig()
143151

144152

145153
@exp_factory.register_config_factory('retinanet')

official/vision/beta/serving/detection.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import tensorflow as tf
1919

20+
from cloud_tpu.models.detection.utils import box_utils
2021
from official.vision.beta import configs
2122
from official.vision.beta.modeling import factory
2223
from official.vision.beta.ops import anchor
@@ -130,6 +131,28 @@ def serve(self, images: tf.Tensor):
130131
training=False)
131132

132133
if self.params.task.model.detection_generator.apply_nms:
134+
# For RetinaNet model, apply export_config.
135+
# TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
136+
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
137+
export_config = self.params.task.export_config
138+
# Normalize detection box coordinates to [0, 1].
139+
if export_config.output_normalized_coordinates:
140+
detection_boxes = (
141+
detections['detection_boxes'] /
142+
tf.tile(image_info[:, 2:3, :], [1, 1, 2]))
143+
detections['detection_boxes'] = box_utils.normalize_boxes(
144+
detection_boxes, image_info[:, 0:1, :])
145+
146+
# Cast num_detections and detection_classes to float. This allows the
147+
# model inference to work on chain (go/chain) as chain requires floating
148+
# point outputs.
149+
if export_config.cast_num_detections_to_float:
150+
detections['num_detections'] = tf.cast(
151+
detections['num_detections'], dtype=tf.float32)
152+
if export_config.cast_detection_classes_to_float:
153+
detections['detection_classes'] = tf.cast(
154+
detections['detection_classes'], dtype=tf.float32)
155+
133156
final_outputs = {
134157
'detection_boxes': detections['detection_boxes'],
135158
'detection_scores': detections['detection_scores'],

0 commit comments

Comments
 (0)