Skip to content

Commit 6497852

Browse files
daavooalykhantejani
authored andcommitted
add rotate and RandomRotation to transforms (#303)
add rotate and RandomRotation to transforms
1 parent 7e61f8d commit 6497852

File tree

3 files changed

+135
-1
lines changed

3 files changed

+135
-1
lines changed

test/test_transforms.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,56 @@ def test_linear_transformation(self):
664664
cov = np.dot(xwhite, xwhite.T) / x.size(0)
665665
assert np.allclose(cov, np.identity(1), rtol=1e-3)
666666

667+
def test_rotate(self):
668+
x = np.zeros((100, 100, 3), dtype=np.uint8)
669+
x[40, 40] = [255, 255, 255]
670+
671+
with self.assertRaises(TypeError):
672+
F.rotate(x, 10)
673+
674+
img = F.to_pil_image(x)
675+
676+
result = F.rotate(img, 45)
677+
assert result.size == (100, 100)
678+
r, c, ch = np.where(result)
679+
assert all(x in r for x in [49, 50])
680+
assert all(x in c for x in [36])
681+
assert all(x in ch for x in [0, 1, 2])
682+
683+
result = F.rotate(img, 45, expand=True)
684+
assert result.size == (142, 142)
685+
r, c, ch = np.where(result)
686+
assert all(x in r for x in [70, 71])
687+
assert all(x in c for x in [57])
688+
assert all(x in ch for x in [0, 1, 2])
689+
690+
result = F.rotate(img, 45, center=(40, 40))
691+
assert result.size == (100, 100)
692+
r, c, ch = np.where(result)
693+
assert all(x in r for x in [40])
694+
assert all(x in c for x in [40])
695+
assert all(x in ch for x in [0, 1, 2])
696+
697+
result_a = F.rotate(img, 90)
698+
result_b = F.rotate(img, -270)
699+
700+
assert np.all(np.array(result_a) == np.array(result_b))
701+
702+
def test_random_rotation(self):
703+
704+
with self.assertRaises(ValueError):
705+
transforms.RandomRotation(-0.7)
706+
transforms.RandomRotation([-0.7])
707+
transforms.RandomRotation([-0.7, 0, 0.7])
708+
709+
t = transforms.RandomRotation(10)
710+
angle = t.get_params(t.degrees)
711+
assert angle > -10 and angle < 10
712+
713+
t = transforms.RandomRotation((-10, 10))
714+
angle = t.get_params(t.degrees)
715+
assert angle > -10 and angle < 10
716+
667717

668718
if __name__ == '__main__':
669719
unittest.main()

torchvision/transforms/functional.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,28 @@ def adjust_gamma(img, gamma, gain=1):
524524

525525
img = Image.fromarray(np_img, 'RGB').convert(input_mode)
526526
return img
527+
528+
529+
def rotate(img, angle, resample=False, expand=False, center=None):
530+
"""Rotate the image by angle and then (optionally) translate it by (n_columns, n_rows)
531+
532+
533+
Args:
534+
img (PIL Image): PIL Image to be rotated.
535+
angle ({float, int}): In degrees degrees counter clockwise order.
536+
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
537+
An optional resampling filter.
538+
See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
539+
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
540+
expand (bool, optional): Optional expansion flag.
541+
If true, expands the output image to make it large enough to hold the entire rotated image.
542+
If false or omitted, make the output image the same size as the input image.
543+
Note that the expand flag assumes rotation around the center and no translation.
544+
center (2-tuple, optional): Optional center of rotation.
545+
Origin is the upper left corner.
546+
Default is the center of the image.
547+
"""
548+
if not _is_pil_image(img):
549+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
550+
551+
return img.rotate(angle, resample, expand, center)

torchvision/transforms/transforms.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
1919
"Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
20-
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter"]
20+
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation"]
2121

2222

2323
class Compose(object):
@@ -570,3 +570,62 @@ def __call__(self, img):
570570
transform = self.get_params(self.brightness, self.contrast,
571571
self.saturation, self.hue)
572572
return transform(img)
573+
574+
575+
class RandomRotation(object):
576+
"""Rotate the image by angle.
577+
578+
Args:
579+
degrees (sequence or float or int): Range of degrees to select from.
580+
If degrees is a number instead of sequence like (min, max), the range of degrees
581+
will be (-degrees, +degrees).
582+
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
583+
An optional resampling filter.
584+
See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
585+
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
586+
expand (bool, optional): Optional expansion flag.
587+
If true, expands the output to make it large enough to hold the entire rotated image.
588+
If false or omitted, make the output image the same size as the input image.
589+
Note that the expand flag assumes rotation around the center and no translation.
590+
center (2-tuple, optional): Optional center of rotation.
591+
Origin is the upper left corner.
592+
Default is the center of the image.
593+
"""
594+
595+
def __init__(self, degrees, resample=False, expand=False, center=None):
596+
if isinstance(degrees, numbers.Number):
597+
if degrees < 0:
598+
raise ValueError("If degrees is a single number, it must be positive.")
599+
self.degrees = (-degrees, degrees)
600+
else:
601+
if len(degrees) != 2:
602+
raise ValueError("If degrees is a sequence, it must be of len 2.")
603+
self.degrees = degrees
604+
605+
self.resample = resample
606+
self.expand = expand
607+
self.center = center
608+
609+
@staticmethod
610+
def get_params(degrees):
611+
"""Get parameters for ``rotate`` for a random rotation.
612+
613+
Returns:
614+
sequence: params to be passed to ``rotate`` for random rotation.
615+
"""
616+
angle = np.random.uniform(degrees[0], degrees[1])
617+
618+
return angle
619+
620+
def __call__(self, img):
621+
"""
622+
Args:
623+
img (PIL Image): Image to be rotated.
624+
625+
Returns:
626+
PIL Image: Rotated image.
627+
"""
628+
629+
angle = self.get_params(self.degrees)
630+
631+
return F.rotate(img, angle, self.resample, self.expand, self.center)

0 commit comments

Comments
 (0)