Skip to content

Commit 482ec55

Browse files
fyangftensorflower-gardener
authored andcommitted
Support all NMS versions besides tflite for Keras QAT RetinaNet model.
PiperOrigin-RevId: 519777521
1 parent 0182948 commit 482ec55

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

official/projects/qat/vision/modeling/factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def build_qat_retinanet(
201201
if quantization.quantize_detection_head:
202202
# Call the model with dummy input to build the head part.
203203
dummpy_input = tf.zeros([1] + model_config.input_size)
204-
optimized_model(dummpy_input, training=True)
204+
height, width, _ = model_config.input_size
205+
image_shape = [[height, width]]
206+
optimized_model.call(dummpy_input, image_shape=image_shape, training=False)
205207
helper.copy_original_weights(model.head, optimized_model.head)
206208
return optimized_model
207209

official/projects/qat/vision/serving/export_module.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
"""Export modules for QAT model serving/inference."""
16-
from absl import logging
1716
import tensorflow as tf
1817

1918
from official.projects.qat.vision.modeling import factory as qat_factory
@@ -50,11 +49,6 @@ class DetectionModule(detection.DetectionModule):
5049
"""Detection Module."""
5150

5251
def _build_model(self):
53-
if self.params.task.model.detection_generator.nms_version != 'tflite':
54-
self.params.task.model.detection_generator.nms_version = 'tflite'
55-
logging.info('Set `nms_version` to `tflite` because only TFLite NMS is '
56-
'supported for QAT detection models.')
57-
5852
model = super()._build_model()
5953

6054
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):

official/vision/modeling/layers/detection_generator.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,7 +1310,10 @@ def _decode_multilevel_outputs_and_pre_nms_top_k(
13101310
levels = list(raw_boxes.keys())
13111311
min_level = int(min(levels))
13121312
max_level = int(max(levels))
1313-
clip_shape = tf.expand_dims(tf.expand_dims(image_shape, axis=1), axis=1)
1313+
if image_shape is not None:
1314+
clip_shape = tf.expand_dims(tf.expand_dims(image_shape, axis=1), axis=1)
1315+
else:
1316+
clip_shape = None
13141317
for i in range(max_level, min_level - 1, -1):
13151318
(
13161319
batch_size,
@@ -1330,13 +1333,15 @@ def _decode_multilevel_outputs_and_pre_nms_top_k(
13301333
unsharded_w * num_anchors_per_locations,
13311334
4,
13321335
]
1333-
decoded_boxes = box_ops.clip_boxes(
1334-
box_ops.decode_boxes(
1335-
tf.reshape(raw_boxes[str(i)], boxes_shape),
1336-
tf.reshape(anchor_boxes[str(i)], boxes_shape),
1337-
),
1338-
clip_shape,
1336+
decoded_boxes = box_ops.decode_boxes(
1337+
tf.reshape(raw_boxes[str(i)], boxes_shape),
1338+
tf.reshape(anchor_boxes[str(i)], boxes_shape),
13391339
)
1340+
if clip_shape is not None:
1341+
decoded_boxes = box_ops.clip_boxes(
1342+
decoded_boxes,
1343+
clip_shape,
1344+
)
13401345
for raw_scores_i, decoded_boxes_i in edgetpu.shard_tensors(
13411346
1, block, (raw_scores[str(i)], decoded_boxes)
13421347
):

0 commit comments

Comments
 (0)