-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathaugment.py
More file actions
423 lines (342 loc) · 13.8 KB
/
augment.py
File metadata and controls
423 lines (342 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""
3Augment implementation
Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
and timm DA (https://github.com/rwightman/pytorch-image-models)
"""
import math
from typing import Tuple
import torch
from torchvision import transforms
import numpy as np
import random
from PIL import ImageFilter, ImageOps
from torch import Tensor
from torchvision.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode
# ---- Custom Augs ----
class GaussianBlur(object):
    """Randomly blur a PIL image with a Gaussian filter.

    Args:
        p: threshold compared against a uniform draw; the blur is skipped
            whenever the draw is strictly greater than ``p``.
        radius_min: lower bound of the uniformly sampled blur radius.
        radius_max: upper bound of the uniformly sampled blur radius.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        """Return ``img`` blurred with a random radius, or unchanged."""
        skip = random.random() > self.prob
        if skip:
            return img
        # The radius is sampled only when the blur is actually applied.
        radius = random.uniform(self.radius_min, self.radius_max)
        blur_filter = ImageFilter.GaussianBlur(radius=radius)
        return img.filter(blur_filter)
class Solarization(object):
    """Randomly apply ``ImageOps.solarize`` to a PIL image.

    Args:
        p: the image is solarized when a uniform draw is strictly below ``p``.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        """Return the solarized image, or ``img`` itself when not triggered."""
        triggered = random.random() < self.p
        return ImageOps.solarize(img) if triggered else img
class GrayScale(object):
    """Randomly convert an image to grayscale while keeping three channels.

    Args:
        p: the conversion runs when a uniform draw is strictly below ``p``.
    """

    def __init__(self, p=0.2):
        # Three output channels so later stages still receive an RGB-shaped image.
        self.transf = transforms.Grayscale(num_output_channels=3)
        self.p = p

    def __call__(self, img):
        """Return the grayscale version of ``img``, or ``img`` unchanged."""
        triggered = random.random() < self.p
        return self.transf(img) if triggered else img
class HorizontalFlip(object):
    """Flip an image horizontally with a caller-controlled probability.

    Args:
        p: the flip runs when a uniform draw is strictly below ``p``.
    """

    def __init__(self, p=0.2):
        # The wrapped transform always flips; the random gate lives in __call__.
        self.transf = transforms.RandomHorizontalFlip(p=1.0)
        self.p = p

    def __call__(self, img):
        """Return the flipped image, or ``img`` unchanged when not triggered."""
        triggered = random.random() < self.p
        return self.transf(img) if triggered else img
# ---- Main Augmentation Pipeline ----
def new_data_aug_generator(args=None, *, remove_random_resized_crop=True):
    """Build the training augmentation pipeline as a ``transforms.Compose``.

    The pipeline is assembled from three stages:

    1. Primary (geometric): either resize + reflect-padded random crop +
       horizontal flip, or random-resized-crop + horizontal flip.
    2. Secondary (appearance): exactly one of grayscale, solarization or
       Gaussian blur chosen at random, optionally followed by color jitter.
    3. Final: conversion to tensor and per-channel normalization.

    Args:
        args: object exposing ``input_size`` and ``color_jitter`` attributes.
            Required in practice: the ``None`` default is kept only for
            signature compatibility and fails on the first attribute access.
        remove_random_resized_crop: when True (the default, matching the
            previously hard-coded behavior) use resize + random crop;
            when False use ``RandomResizedCrop`` with scale ``(0.08, 1.0)``.

    Returns:
        A ``transforms.Compose`` applying the three stages in order.
    """
    img_size = args.input_size
    # Per-channel normalization constants (presumably ImageNet statistics).
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    bicubic = transforms.InterpolationMode.BICUBIC

    # Primary transforms
    if remove_random_resized_crop:
        primary_tfl = [
            transforms.Resize(img_size, interpolation=bicubic),
            transforms.RandomCrop(img_size, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        primary_tfl = [
            transforms.RandomResizedCrop(
                img_size, scale=(0.08, 1.0), interpolation=bicubic
            ),
            transforms.RandomHorizontalFlip(),
        ]

    # Secondary transforms (color/blur/solarize/gray): each candidate uses
    # p=1.0 so that whichever one RandomChoice selects is always applied.
    secondary_tfl = [
        transforms.RandomChoice([
            GrayScale(p=1.0),
            Solarization(p=1.0),
            GaussianBlur(p=1.0),
        ])
    ]
    if args.color_jitter is not None and args.color_jitter != 0:
        # The same strength is used for the first three ColorJitter arguments.
        secondary_tfl.append(
            transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)
        )

    # Final transforms (tensor + normalization)
    final_tfl = [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
    """Build a batch-level MixUp/CutMix transform, or ``None`` if both are off.

    Args:
        mixup_alpha: MixUp alpha; MixUp is enabled only when this is > 0.
        cutmix_alpha: CutMix alpha; CutMix is enabled only when this is > 0.
        num_classes: number of classes forwarded to the chosen transforms.
        use_v2: when truthy, use the ``MixUp``/``CutMix`` classes of the module
            returned by ``get_module``; otherwise use this file's
            ``RandomMixUp``/``RandomCutMix`` with ``p=1.0``.

    Returns:
        ``RandomChoice`` over the enabled transforms, or ``None`` when neither
        alpha is positive.
    """
    T = get_module(use_v2)
    candidates = []

    if mixup_alpha > 0:
        if use_v2:
            mixup = T.MixUp(alpha=mixup_alpha, num_classes=num_classes)
        else:
            mixup = RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
        candidates.append(mixup)

    if cutmix_alpha > 0:
        if use_v2:
            cutmix = T.CutMix(alpha=cutmix_alpha, num_classes=num_classes)
        else:
            cutmix = RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha)
        candidates.append(cutmix)

    return T.RandomChoice(candidates) if candidates else None
class RandomMixUp(torch.nn.Module):
    """Randomly apply MixUp to a batch of images and its targets.

    Implements the augmentation from
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): when True the input batch tensor is modified in place.
            Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Mix each sample with its neighbour in the batch.

        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: the (possibly mixed) batch and the targets
            one-hot encoded in the batch dtype, mixed with the same weight.

        Raises:
            ValueError: if ``batch`` is not 4-D or ``target`` is not 1-D.
            TypeError: if ``batch`` is not floating point or ``target`` is not int64.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # target is known to be 1-D at this point, so it is always one-hot encoded.
        soft_target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, soft_target

        # Implemented as on mixup paper, page 3.
        mix_weight = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        partner_weight = 1.0 - mix_weight

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        partner_batch = batch.roll(1, 0).mul_(partner_weight)
        batch.mul_(mix_weight).add_(partner_batch)

        partner_target = soft_target.roll(1, 0).mul_(partner_weight)
        soft_target.mul_(mix_weight).add_(partner_target)

        return batch, soft_target

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(num_classes={self.num_classes}, p={self.p}, "
            f"alpha={self.alpha}, inplace={self.inplace})"
        )
class RandomCutMix(torch.nn.Module):
    """Randomly apply CutMix to a batch of images and its targets.

    Implements the augmentation from
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): when True the input batch tensor is modified in place.
            Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError("Please provide a valid positive value for the num_classes.")
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Paste a random rectangle from each sample's neighbour in the batch.

        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: the (possibly patched) batch and the targets
            one-hot encoded in the batch dtype, mixed by the kept-area ratio.

        Raises:
            ValueError: if ``batch`` is not 4-D or ``target`` is not 1-D.
            TypeError: if ``batch`` is not floating point or ``target`` is not int64.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # target is known to be 1-D at this point, so it is always one-hot encoded.
        soft_target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, soft_target

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        sampled_lambda = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        _, height, width = F.get_dimensions(batch)

        # Random centre of the patch, then half-extents derived from lambda.
        center_x = torch.randint(width, (1,))
        center_y = torch.randint(height, (1,))

        half_ratio = 0.5 * math.sqrt(1.0 - sampled_lambda)
        half_w = int(half_ratio * width)
        half_h = int(half_ratio * height)

        # Clip the rectangle to the image bounds.
        left = int(torch.clamp(center_x - half_w, min=0))
        top = int(torch.clamp(center_y - half_h, min=0))
        right = int(torch.clamp(center_x + half_w, max=width))
        bottom = int(torch.clamp(center_y + half_h, max=height))

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        partner_batch = batch.roll(1, 0)
        batch[:, :, top:bottom, left:right] = partner_batch[:, :, top:bottom, left:right]

        # Recompute lambda from the clipped rectangle so the labels match the pixels.
        kept_ratio = float(1.0 - (right - left) * (bottom - top) / (width * height))

        partner_target = soft_target.roll(1, 0).mul_(1.0 - kept_ratio)
        soft_target.mul_(kept_ratio).add_(partner_target)

        return batch, soft_target

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(num_classes={self.num_classes}, p={self.p}, "
            f"alpha={self.alpha}, inplace={self.inplace})"
        )
def get_module(use_v2):
    """Return the torchvision transforms namespace to build pipelines from.

    Args:
        use_v2: when truthy return ``torchvision.transforms.v2``, otherwise
            return ``torchvision.transforms``.
    """
    # We need a protected import to avoid the V2 warning in case just V1 is used
    if not use_v2:
        import torchvision.transforms as v1_module

        return v1_module

    import torchvision.transforms.v2 as v2_module

    return v2_module
class ClassificationPresetTrain:
    """Training-time preprocessing pipeline for image classification.

    Builds, in order: optional early ``PILToTensor`` (tensor backend), random
    resized crop, optional horizontal flip, optional auto-augmentation policy,
    ``PILToTensor`` (pil backend), float conversion, normalization and optional
    random erasing. The composed pipeline is stored on ``self.transforms``.

    Args:
        crop_size: output size passed to ``RandomResizedCrop``.
        mean: per-channel mean passed to ``Normalize``.
        std: per-channel std passed to ``Normalize``.
        interpolation: interpolation mode for the crop and augmentation policies.
        hflip_prob: horizontal flip probability; the flip is omitted when <= 0.
        auto_augment_policy: ``None``, ``"ra"``, ``"ta_wide"``, ``"augmix"`` or any
            value accepted by ``AutoAugmentPolicy``.
        ra_magnitude: magnitude used when the policy is ``"ra"``.
        augmix_severity: severity used when the policy is ``"augmix"``.
        random_erase_prob: random erasing probability; omitted when <= 0.
        backend: ``"pil"`` or ``"tensor"`` (case-insensitive).
        use_v2: forwarded to ``get_module`` to select the transforms namespace.

    Raises:
        ValueError: if ``backend`` is neither ``"tensor"`` nor ``"pil"``.
    """

    # Note: this transform assumes that the input to forward() are always PIL
    # images, regardless of the backend parameter. We may change that in the
    # future though, if we change the output type from the dataset.
    def __init__(
        self,
        *,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        ra_magnitude=9,
        augmix_severity=3,
        random_erase_prob=0.0,
        backend="pil",
        use_v2=False,
    ):
        T = get_module(use_v2)

        backend = backend.lower()
        if backend not in ("tensor", "pil"):
            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
        use_tensor_backend = backend == "tensor"

        # Tensor backend: convert up front so every later op runs on tensors.
        steps = [T.PILToTensor()] if use_tensor_backend else []

        steps.append(T.RandomResizedCrop(crop_size, interpolation=interpolation))
        if hflip_prob > 0:
            steps.append(T.RandomHorizontalFlip(hflip_prob))

        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                policy_step = T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)
            elif auto_augment_policy == "ta_wide":
                policy_step = T.TrivialAugmentWide(interpolation=interpolation)
            elif auto_augment_policy == "augmix":
                policy_step = T.AugMix(interpolation=interpolation, severity=augmix_severity)
            else:
                # Any other value must be a valid AutoAugmentPolicy name.
                aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
                policy_step = T.AutoAugment(policy=aa_policy, interpolation=interpolation)
            steps.append(policy_step)

        # PIL backend: convert only after the PIL-based augmentations are done.
        if not use_tensor_backend:
            steps.append(T.PILToTensor())

        steps.append(T.ConvertImageDtype(torch.float))
        steps.append(T.Normalize(mean=mean, std=std))

        if random_erase_prob > 0:
            steps.append(T.RandomErasing(p=random_erase_prob))

        self.transforms = T.Compose(steps)

    def __call__(self, img):
        """Apply the composed training pipeline to ``img``."""
        return self.transforms(img)
class ClassificationPresetEval:
    """Evaluation-time preprocessing pipeline for image classification.

    Builds, in order: optional early ``PILToTensor`` (tensor backend), resize
    with antialiasing, center crop, ``PILToTensor`` (pil backend), float
    conversion and normalization. The composed pipeline is stored on
    ``self.transforms``.

    Args:
        crop_size: output size passed to ``CenterCrop``.
        resize_size: size passed to ``Resize`` before cropping.
        mean: per-channel mean passed to ``Normalize``.
        std: per-channel std passed to ``Normalize``.
        interpolation: interpolation mode used by ``Resize``.
        backend: ``"pil"`` or ``"tensor"`` (case-insensitive).
        use_v2: forwarded to ``get_module`` to select the transforms namespace.

    Raises:
        ValueError: if ``backend`` is neither ``"tensor"`` nor ``"pil"``.
    """

    def __init__(
        self,
        *,
        crop_size,
        resize_size=256,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        backend="pil",
        use_v2=False,
    ):
        T = get_module(use_v2)

        backend = backend.lower()
        if backend not in ("tensor", "pil"):
            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
        use_tensor_backend = backend == "tensor"

        # Tensor backend: convert up front so resize/crop run on tensors.
        steps = [T.PILToTensor()] if use_tensor_backend else []

        steps.append(T.Resize(resize_size, interpolation=interpolation, antialias=True))
        steps.append(T.CenterCrop(crop_size))

        # PIL backend: convert only after the geometric ops are done.
        if not use_tensor_backend:
            steps.append(T.PILToTensor())

        steps.append(T.ConvertImageDtype(torch.float))
        steps.append(T.Normalize(mean=mean, std=std))

        self.transforms = T.Compose(steps)

    def __call__(self, img):
        """Apply the composed evaluation pipeline to ``img``."""
        return self.transforms(img)