|
17 | 17 |
|
18 | 18 | import tensorflow as tf
|
19 | 19 |
|
| 20 | +from cloud_tpu.models.detection.utils import box_utils |
20 | 21 | from official.vision.beta import configs
|
21 | 22 | from official.vision.beta.modeling import factory
|
22 | 23 | from official.vision.beta.ops import anchor
|
@@ -130,6 +131,28 @@ def serve(self, images: tf.Tensor):
|
130 | 131 | training=False)
|
131 | 132 |
|
132 | 133 | if self.params.task.model.detection_generator.apply_nms:
|
| 134 | + # For RetinaNet model, apply export_config. |
| 135 | + # TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed. |
| 136 | + if isinstance(self.params.task.model, configs.retinanet.RetinaNet): |
| 137 | + export_config = self.params.task.export_config |
| 138 | + # Normalize detection box coordinates to [0, 1]. |
| 139 | + if export_config.output_normalized_coordinates: |
| 140 | + detection_boxes = ( |
| 141 | + detections['detection_boxes'] / |
| 142 | + tf.tile(image_info[:, 2:3, :], [1, 1, 2])) |
| 143 | + detections['detection_boxes'] = box_utils.normalize_boxes( |
| 144 | + detection_boxes, image_info[:, 0:1, :]) |
| 145 | + |
| 146 | + # Cast num_detections and detection_classes to float. This allows the |
| 147 | + # model inference to work on chain (go/chain) as chain requires floating |
| 148 | + # point outputs. |
| 149 | + if export_config.cast_num_detections_to_float: |
| 150 | + detections['num_detections'] = tf.cast( |
| 151 | + detections['num_detections'], dtype=tf.float32) |
| 152 | + if export_config.cast_detection_classes_to_float: |
| 153 | + detections['detection_classes'] = tf.cast( |
| 154 | + detections['detection_classes'], dtype=tf.float32) |
| 155 | + |
133 | 156 | final_outputs = {
|
134 | 157 | 'detection_boxes': detections['detection_boxes'],
|
135 | 158 | 'detection_scores': detections['detection_scores'],
|
|
0 commit comments