Skip to content

Commit c9a7e0b

Browse files
Add builder that applies bounding box-specific ops for RandAugment
PiperOrigin-RevId: 421439862
1 parent 49a5706 commit c9a7e0b

File tree

4 files changed

+58
-4
lines changed

4 files changed

+58
-4
lines changed

official/vision/beta/configs/retinanet.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,6 @@ class Parser(hyperparams.Config):
5858
skip_crowd_during_training: bool = True
5959
max_num_instances: int = 100
6060
# Can choose AutoAugment and RandAugment.
61-
# TODO(b/205346436) Support RandAugment.
6261
aug_type: Optional[common.Augmentation] = None
6362

6463
# Keep for backward compatibility. Not used.

official/vision/beta/dataloaders/retinanet_input.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ def __init__(self,
7575
upper-bound threshold to assign negative labels for anchors. An anchor
7676
with a score below the threshold is labeled negative.
7777
aug_type: An optional Augmentation object to choose from AutoAugment and
78-
RandAugment. The latter is not supported, and will raise ValueError.
78+
RandAugment.
7979
aug_rand_hflip: `bool`, if True, augment training with random horizontal
8080
flip.
8181
aug_scale_min: `float`, the minimum scale applied to `output_size` for
@@ -122,8 +122,16 @@ def __init__(self,
122122
augmentation_name=aug_type.autoaug.augmentation_name,
123123
cutout_const=aug_type.autoaug.cutout_const,
124124
translate_const=aug_type.autoaug.translate_const)
125+
elif aug_type.type == 'randaug':
126+
logging.info('Using RandAugment.')
127+
self._augmenter = augment.RandAugment.build_for_detection(
128+
num_layers=aug_type.randaug.num_layers,
129+
magnitude=aug_type.randaug.magnitude,
130+
cutout_const=aug_type.randaug.cutout_const,
131+
translate_const=aug_type.randaug.translate_const,
132+
prob_to_apply=aug_type.randaug.prob_to_apply,
133+
exclude_ops=aug_type.randaug.exclude_ops)
125134
else:
126-
# TODO(b/205346436) Support RandAugment.
127135
raise ValueError(f'Augmentation policy {aug_type.type} not supported.')
128136

129137
# Deprecated. Data Augmentation with AutoAugment.
@@ -162,7 +170,6 @@ def _parse_train_data(self, data):
162170
# Apply autoaug or randaug.
163171
if self._augmenter is not None:
164172
image, boxes = self._augmenter.distort_with_boxes(image, boxes)
165-
166173
image_shape = tf.shape(input=image)[0:2]
167174

168175
# Normalizes image with mean and std pixel values.

official/vision/beta/ops/augment.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1950,6 +1950,37 @@ def __init__(self,
19501950
op for op in self.available_ops if op not in exclude_ops
19511951
]
19521952

1953+
@classmethod
def build_for_detection(cls,
                        num_layers: int = 2,
                        magnitude: float = 10.,
                        cutout_const: float = 40.,
                        translate_const: float = 100.,
                        magnitude_std: float = 0.0,
                        prob_to_apply: Optional[float] = None,
                        exclude_ops: Optional[List[str]] = None
                       ) -> 'RandAugment':
  """Builds a RandAugment that modifies bboxes for geometric transforms.

  The augmenter is first constructed through the regular constructor with the
  arguments given here. Afterwards every geometric op name in its
  `available_ops` (Rotate, ShearX, ShearY, TranslateX, TranslateY) is replaced
  by the corresponding `*_BBox` op name; all other op names are kept as is.

  Args:
    num_layers: Forwarded unchanged to the constructor.
    magnitude: Forwarded unchanged to the constructor.
    cutout_const: Forwarded unchanged to the constructor.
    translate_const: Forwarded unchanged to the constructor.
    magnitude_std: Forwarded unchanged to the constructor.
    prob_to_apply: Forwarded unchanged to the constructor.
    exclude_ops: Forwarded unchanged to the constructor. Because the
      constructor filters ops before the renaming below happens, entries
      should use the base op names (e.g. 'Rotate'), not the `*_BBox` names.

  Returns:
    An instance of `cls` whose `available_ops` uses the box-aware variants of
    the geometric ops.
  """
  augmenter = cls(
      num_layers=num_layers,
      magnitude=magnitude,
      cutout_const=cutout_const,
      translate_const=translate_const,
      magnitude_std=magnitude_std,
      prob_to_apply=prob_to_apply,
      exclude_ops=exclude_ops)
  box_aware_ops_by_base_name = {
      'Rotate': 'Rotate_BBox',
      'ShearX': 'ShearX_BBox',
      'ShearY': 'ShearY_BBox',
      'TranslateX': 'TranslateX_BBox',
      'TranslateY': 'TranslateY_BBox',
  }
  # Ops without a box-aware counterpart (color ops, Cutout, ...) map to
  # themselves, so the order and length of the op list are preserved.
  augmenter.available_ops = [
      box_aware_ops_by_base_name.get(op_name, op_name)
      for op_name in augmenter.available_ops
  ]
  return augmenter
1983+
19531984
def _distort_common(
19541985
self,
19551986
image: tf.Tensor,

official/vision/beta/ops/augment_test.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,23 @@ def test_randaug_with_bboxes(self):
140140
self.assertEqual((224, 224, 3), aug_image.shape)
141141
self.assertEqual((2, 4), aug_bboxes.shape)
142142

143+
def test_randaug_build_for_detection(self):
  """Smoke test that the detection builder uses box-aware ops and runs."""
  # Geometric ops are expected under their `_BBox` names; the remaining ops
  # keep their plain names.
  expected_ops = [
      'AutoContrast', 'Equalize', 'Invert', 'Posterize', 'Solarize', 'Color',
      'Contrast', 'Brightness', 'Sharpness', 'Cutout', 'SolarizeAdd',
      'Rotate_BBox', 'ShearX_BBox', 'ShearY_BBox', 'TranslateX_BBox',
      'TranslateY_BBox'
  ]
  detection_augmenter = augment.RandAugment.build_for_detection()
  self.assertCountEqual(detection_augmenter.available_ops, expected_ops)

  # Distorting a dummy image with two boxes must preserve both shapes.
  blank_image = tf.zeros((224, 224, 3), dtype=tf.uint8)
  unit_boxes = tf.ones((2, 4), dtype=tf.float32)
  distorted_image, distorted_boxes = detection_augmenter.distort_with_boxes(
      blank_image, unit_boxes)
  self.assertEqual((224, 224, 3), distorted_image.shape)
  self.assertEqual((2, 4), distorted_boxes.shape)
159+
143160
def test_all_policy_ops(self):
144161
"""Smoke test to be sure all augmentation functions can execute."""
145162

0 commit comments

Comments
 (0)