19
19
"""
20
20
21
21
# Import libraries
22
+ from absl import logging
22
23
import tensorflow as tf
23
24
24
25
from official .vision .beta .dataloaders import parser
25
26
from official .vision .beta .dataloaders import utils
26
27
from official .vision .beta .ops import anchor
28
+ from official .vision .beta .ops import augment
27
29
from official .vision .beta .ops import box_ops
28
30
from official .vision .beta .ops import preprocess_ops
29
31
@@ -40,6 +42,7 @@ def __init__(self,
40
42
anchor_size ,
41
43
match_threshold = 0.5 ,
42
44
unmatched_threshold = 0.5 ,
45
+ aug_type = None ,
43
46
aug_rand_hflip = False ,
44
47
aug_scale_min = 1.0 ,
45
48
aug_scale_max = 1.0 ,
@@ -71,6 +74,8 @@ def __init__(self,
71
74
unmatched_threshold: `float` number between 0 and 1 representing the
72
75
upper-bound threshold to assign negative labels for anchors. An anchor
73
76
with a score below the threshold is labeled negative.
77
+ aug_type: An optional Augmentation object to choose from AutoAugment and
78
+ RandAugment. The latter is not supported, and will raise ValueError.
74
79
aug_rand_hflip: `bool`, if True, augment training with random horizontal
75
80
flip.
76
81
aug_scale_min: `float`, the minimum scale applied to `output_size` for
@@ -108,7 +113,20 @@ def __init__(self,
108
113
self ._aug_scale_min = aug_scale_min
109
114
self ._aug_scale_max = aug_scale_max
110
115
111
- # Data Augmentation with AutoAugment.
116
+ # Data augmentation with AutoAugment or RandAugment.
117
+ self ._augmenter = None
118
+ if aug_type is not None :
119
+ if aug_type .type == 'autoaug' :
120
+ logging .info ('Using AutoAugment.' )
121
+ self ._augmenter = augment .AutoAugment (
122
+ augmentation_name = aug_type .autoaug .augmentation_name ,
123
+ cutout_const = aug_type .autoaug .cutout_const ,
124
+ translate_const = aug_type .autoaug .translate_const )
125
+ else :
126
+ # TODO(b/205346436) Support RandAugment.
127
+ raise ValueError (f'Augmentation policy { aug_type .type } not supported.' )
128
+
129
+ # Deprecated. Data Augmentation with AutoAugment.
112
130
self ._use_autoaugment = use_autoaugment
113
131
self ._autoaugment_policy_name = autoaugment_policy_name
114
132
@@ -138,9 +156,13 @@ def _parse_train_data(self, data):
138
156
for k , v in attributes .items ():
139
157
attributes [k ] = tf .gather (v , indices )
140
158
141
- # Gets original image and its size .
159
+ # Gets original image.
142
160
image = data ['image' ]
143
161
162
+ # Apply autoaug or randaug.
163
+ if self ._augmenter is not None :
164
+ image , boxes = self ._augmenter .distort_with_boxes (image , boxes )
165
+
144
166
image_shape = tf .shape (input = image )[0 :2 ]
145
167
146
168
# Normalizes image with mean and std pixel values.
0 commit comments