Skip to content

Commit 5b7cd01

Browse files
Internal change
PiperOrigin-RevId: 416886349
1 parent 782e39e commit 5b7cd01

File tree

5 files changed

+875
-63
lines changed

5 files changed

+875
-63
lines changed

official/vision/beta/configs/retinanet.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,14 @@ class Parser(hyperparams.Config):
5555
aug_rand_hflip: bool = False
5656
aug_scale_min: float = 1.0
5757
aug_scale_max: float = 1.0
58-
aug_policy: Optional[str] = None
5958
skip_crowd_during_training: bool = True
6059
max_num_instances: int = 100
60+
# Augmentation type: either AutoAugment or RandAugment.
61+
# TODO(b/205346436) Support RandAugment.
62+
aug_type: Optional[common.Augmentation] = None
63+
64+
# Kept for backward compatibility. Not used.
65+
aug_policy: Optional[str] = None
6166

6267

6368
@dataclasses.dataclass

official/vision/beta/dataloaders/retinanet_input.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
"""
2020

2121
# Import libraries
22+
from absl import logging
2223
import tensorflow as tf
2324

2425
from official.vision.beta.dataloaders import parser
2526
from official.vision.beta.dataloaders import utils
2627
from official.vision.beta.ops import anchor
28+
from official.vision.beta.ops import augment
2729
from official.vision.beta.ops import box_ops
2830
from official.vision.beta.ops import preprocess_ops
2931

@@ -40,6 +42,7 @@ def __init__(self,
4042
anchor_size,
4143
match_threshold=0.5,
4244
unmatched_threshold=0.5,
45+
aug_type=None,
4346
aug_rand_hflip=False,
4447
aug_scale_min=1.0,
4548
aug_scale_max=1.0,
@@ -71,6 +74,8 @@ def __init__(self,
7174
unmatched_threshold: `float` number between 0 and 1 representing the
7275
upper-bound threshold to assign negative labels for anchors. An anchor
7376
with a score below the threshold is labeled negative.
77+
aug_type: An optional Augmentation object selecting either AutoAugment or
78+
RandAugment. RandAugment is not yet supported and raises a ValueError.
7479
aug_rand_hflip: `bool`, if True, augment training with random horizontal
7580
flip.
7681
aug_scale_min: `float`, the minimum scale applied to `output_size` for
@@ -108,7 +113,20 @@ def __init__(self,
108113
self._aug_scale_min = aug_scale_min
109114
self._aug_scale_max = aug_scale_max
110115

111-
# Data Augmentation with AutoAugment.
116+
# Data augmentation with AutoAugment or RandAugment.
117+
self._augmenter = None
118+
if aug_type is not None:
119+
if aug_type.type == 'autoaug':
120+
logging.info('Using AutoAugment.')
121+
self._augmenter = augment.AutoAugment(
122+
augmentation_name=aug_type.autoaug.augmentation_name,
123+
cutout_const=aug_type.autoaug.cutout_const,
124+
translate_const=aug_type.autoaug.translate_const)
125+
else:
126+
# TODO(b/205346436) Support RandAugment.
127+
raise ValueError(f'Augmentation policy {aug_type.type} not supported.')
128+
129+
# Deprecated. Data Augmentation with AutoAugment.
112130
self._use_autoaugment = use_autoaugment
113131
self._autoaugment_policy_name = autoaugment_policy_name
114132

@@ -138,9 +156,13 @@ def _parse_train_data(self, data):
138156
for k, v in attributes.items():
139157
attributes[k] = tf.gather(v, indices)
140158

141-
# Gets original image and its size.
159+
# Gets original image.
142160
image = data['image']
143161

162+
# Apply autoaug or randaug.
163+
if self._augmenter is not None:
164+
image, boxes = self._augmenter.distort_with_boxes(image, boxes)
165+
144166
image_shape = tf.shape(input=image)[0:2]
145167

146168
# Normalizes image with mean and std pixel values.

0 commit comments

Comments
 (0)