Skip to content

Commit 5a08ff8

Browse files
Internal change
PiperOrigin-RevId: 531182533
1 parent 03083cd commit 5a08ff8

File tree

1 file changed

+53
-13
lines changed

1 file changed

+53
-13
lines changed

official/vision/dataloaders/retinanet_input.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
into (image, labels) tuple for RetinaNet.
1919
"""
2020

21+
from typing import Optional
22+
2123
# Import libraries
24+
2225
from absl import logging
2326
import tensorflow as tf
2427

@@ -51,6 +54,7 @@ def __init__(self,
5154
skip_crowd_during_training=True,
5255
max_num_instances=100,
5356
dtype='bfloat16',
57+
resize_first: Optional[bool] = None,
5458
mode=None):
5559
"""Initializes parameters for parsing annotations in the dataset.
5660
@@ -91,6 +95,8 @@ def __init__(self,
9195
max_num_instances: `int` number of maximum number of instances in an
9296
image. The groundtruth data will be padded to `max_num_instances`.
9397
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
98+
resize_first: Optional `bool`; if True, resizes the image before applying
99+
augmentations, which is computationally more efficient.
94100
mode: a ModeKeys. Specifies if this is training, evaluation, prediction or
95101
prediction with ground-truths in the outputs.
96102
"""
@@ -141,6 +147,30 @@ def __init__(self,
141147
# Data type.
142148
self._dtype = dtype
143149

150+
# Input pipeline optimization.
151+
self._resize_first = resize_first
152+
153+
def _resize_and_crop_image_and_boxes(self, image, boxes, pad=True):
154+
"""Resizes and crops image and boxes, optionally with padding."""
155+
# Resizes and crops image.
156+
padded_size = None
157+
if pad:
158+
padded_size = preprocess_ops.compute_padded_size(self._output_size,
159+
2**self._max_level)
160+
image, image_info = preprocess_ops.resize_and_crop_image(
161+
image,
162+
self._output_size,
163+
padded_size=padded_size,
164+
aug_scale_min=self._aug_scale_min,
165+
aug_scale_max=self._aug_scale_max)
166+
167+
# Resizes and crops boxes.
168+
image_scale = image_info[2, :]
169+
offset = image_info[3, :]
170+
boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_scale,
171+
image_info[1, :], offset)
172+
return image, boxes, image_info
173+
144174
def _parse_train_data(self, data, anchor_labeler=None):
145175
"""Parses data for training and evaluation."""
146176
classes = data['groundtruth_classes']
@@ -165,6 +195,21 @@ def _parse_train_data(self, data, anchor_labeler=None):
165195

166196
# Gets original image.
167197
image = data['image']
198+
image_size = tf.cast(tf.shape(image)[0:2], tf.float32)
199+
200+
less_output_pixels = (
201+
self._output_size[0] * self._output_size[1]
202+
) < image_size[0] * image_size[1]
203+
204+
# Resizing first can reduce augmentation computation if the original image
205+
# has more pixels than the desired output image.
206+
# There might be a smarter threshold to compute less_output_pixels as
207+
# we keep the padding to the very end, i.e., a resized image likely has fewer
208+
# pixels than self._output_size[0] * self._output_size[1].
209+
resize_first = self._resize_first and less_output_pixels
210+
if resize_first:
211+
image, boxes, image_info = self._resize_and_crop_image_and_boxes(
212+
image, boxes, pad=False)
168213

169214
# Apply autoaug or randaug.
170215
if self._augmenter is not None:
@@ -181,21 +226,16 @@ def _parse_train_data(self, data, anchor_labeler=None):
181226
# Converts boxes from normalized coordinates to pixel coordinates.
182227
boxes = box_ops.denormalize_boxes(boxes, image_shape)
183228

184-
# Resizes and crops image.
185-
image, image_info = preprocess_ops.resize_and_crop_image(
186-
image,
187-
self._output_size,
188-
padded_size=preprocess_ops.compute_padded_size(self._output_size,
189-
2**self._max_level),
190-
aug_scale_min=self._aug_scale_min,
191-
aug_scale_max=self._aug_scale_max)
229+
if not resize_first:
230+
image, boxes, image_info = self._resize_and_crop_image_and_boxes(
231+
image, boxes, pad=True)
232+
else:
233+
padded_size = preprocess_ops.compute_padded_size(self._output_size,
234+
2**self._max_level)
235+
image = tf.image.pad_to_bounding_box(
236+
image, 0, 0, padded_size[0], padded_size[1])
192237
image_height, image_width, _ = image.get_shape().as_list()
193238

194-
# Resizes and crops boxes.
195-
image_scale = image_info[2, :]
196-
offset = image_info[3, :]
197-
boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_scale,
198-
image_info[1, :], offset)
199239
# Filters out ground-truth boxes that are all zeros.
200240
indices = box_ops.get_non_empty_box_indices(boxes)
201241
boxes = tf.gather(boxes, indices)

0 commit comments

Comments
 (0)