Skip to content

Commit 91b4459

Browse files
authored
Add Perspective fill option (#1973)
* Add fill option to RandomPerspective (#1972)
* Minor fix to docstring syntax
* Add _parse_fill() to get fillcolor (#1972)
* Minor refactoring as per comments
* Added test for RandomPerspective with fillcolor
* Force perspective transform in test
1 parent e61b68e commit 91b4459

File tree

3 files changed

+85
-25
lines changed

3 files changed

+85
-25
lines changed

test/test_transforms.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,41 @@ def test_randomperspective(self):
177177
self.assertGreater(torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3,
178178
torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))
179179

180+
def test_randomperspective_fill(self):
    """Check that the fill color propagates through both the transform class and F.perspective."""
    height, width = 100, 100
    base_img = transforms.ToPILImage()(torch.ones(3, height, width))

    modes = ("L", "RGB", "F")
    nums_bands = [len(mode) for mode in modes]
    fill = 127

    def as_tuple(pixel):
        # Single-band images return a scalar from getpixel(); normalize to a tuple.
        return pixel if isinstance(pixel, tuple) else (pixel,)

    # Transform-class path: p=1 forces the perspective warp to be applied.
    for mode, num_bands in zip(modes, nums_bands):
        img_conv = base_img.convert(mode)
        transformed = transforms.RandomPerspective(p=1, fill=fill)(img_conv)
        corner = as_tuple(transformed.getpixel((0, 0)))
        self.assertTupleEqual(corner, (fill,) * num_bands)

    # Functional path: call F.perspective directly with explicit control points.
    for mode, num_bands in zip(modes, nums_bands):
        img_conv = base_img.convert(mode)
        startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
        transformed = F.perspective(img_conv, startpoints, endpoints, fill=fill)
        corner = as_tuple(transformed.getpixel((0, 0)))
        self.assertTupleEqual(corner, (fill,) * num_bands)

    # A fill tuple whose length disagrees with the image's band count must raise.
    # NOTE: reuses img_conv/startpoints/endpoints/num_bands from the last loop iteration.
    for wrong_num_bands in set(nums_bands) - {num_bands}:
        with self.assertRaises(ValueError):
            F.perspective(img_conv, startpoints, endpoints, fill=(fill,) * wrong_num_bands)
180215
def test_resize(self):
181216
height = random.randint(24, 32) * 2
182217
width = random.randint(24, 32) * 2

torchvision/transforms/functional.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,41 @@ def hflip(img):
425425
return img.transpose(Image.FLIP_LEFT_RIGHT)
426426

427427

428+
def _parse_fill(fill, img, min_pil_version):
    """Helper function to get the fill color for rotate and perspective transforms.

    Args:
        fill (n-tuple or int or float): Pixel fill value for area outside the transformed
            image. If int or float, the value is used for all bands respectively.
            Defaults to 0 for all bands.
        img (PIL Image): Image to be filled.
        min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
            was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)

    Returns:
        dict: kwarg for ``fillcolor`` (empty when fill is unused and unsupported).

    Raises:
        RuntimeError: if a fill value is given but the installed pillow is older
            than ``min_pil_version``.
        ValueError: if ``fill`` is a sequence whose length does not match the
            number of bands in ``img``.
    """
    def version_tuple(version):
        # Compare versions numerically, not lexicographically: as strings,
        # "10.0.0" < "5.0.0" and "5.10.0" < "5.2.0", which would wrongly
        # disable the fill option on newer pillow releases. Non-numeric
        # suffixes (e.g. the "dev0" in "5.4.0.dev0") are dropped.
        parts = []
        for part in version.split('.'):
            if not part.isdigit():
                break
            parts.append(int(part))
        return tuple(parts)

    if version_tuple(PILLOW_VERSION) < version_tuple(min_pil_version):
        if fill is None:
            return {}
        else:
            msg = ("The option to fill background area of the transformed image, "
                   "requires pillow>={}")
            raise RuntimeError(msg.format(min_pil_version))

    num_bands = len(img.getbands())
    if fill is None:
        fill = 0
    if isinstance(fill, (int, float)) and num_bands > 1:
        # Broadcast a scalar to every band, e.g. 127 -> (127, 127, 127) for RGB.
        fill = tuple([fill] * num_bands)
    if not isinstance(fill, (int, float)) and len(fill) != num_bands:
        msg = ("The number of elements in 'fill' does not match the number of "
               "bands of the image ({} != {})")
        raise ValueError(msg.format(len(fill), num_bands))

    return {"fillcolor": fill}
461+
462+
428463
def _get_perspective_coeffs(startpoints, endpoints):
429464
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
430465
@@ -450,22 +485,29 @@ def _get_perspective_coeffs(startpoints, endpoints):
450485
return res.squeeze_(1).tolist()
451486

452487

453-
def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=None):
    """Perform perspective transform of the given PIL Image.

    Args:
        img (PIL Image): Image to be transformed.
        startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the original image
        endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
        interpolation: Default- Image.BICUBIC
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively.
            This option is only available for ``pillow>=5.0.0``.

    Returns:
        PIL Image: Perspectively transformed Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # Resolve the fill color (and the pillow-version gate) before transforming.
    fill_kwargs = _parse_fill(fill, img, '5.0.0')

    coefficients = _get_perspective_coeffs(startpoints, endpoints)
    return img.transform(img.size, Image.PERSPECTIVE, coefficients, interpolation, **fill_kwargs)
469511

470512

471513
def vflip(img):
@@ -721,30 +763,10 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
721763
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
722764
723765
"""
724-
def parse_fill(fill, num_bands):
725-
if PILLOW_VERSION < "5.2.0":
726-
if fill is None:
727-
return {}
728-
else:
729-
msg = ("The option to fill background area of the rotated image, "
730-
"requires pillow>=5.2.0")
731-
raise RuntimeError(msg)
732-
733-
if fill is None:
734-
fill = 0
735-
if isinstance(fill, (int, float)) and num_bands > 1:
736-
fill = tuple([fill] * num_bands)
737-
if not isinstance(fill, (int, float)) and len(fill) != num_bands:
738-
msg = ("The number of elements in 'fill' does not match the number of "
739-
"bands of the image ({} != {})")
740-
raise ValueError(msg.format(len(fill), num_bands))
741-
742-
return {"fillcolor": fill}
743-
744766
if not _is_pil_image(img):
745767
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
746768

747-
opts = parse_fill(fill, len(img.getbands()))
769+
opts = _parse_fill(fill, img, '5.2.0')
748770

749771
return img.rotate(angle, resample, expand, center, **opts)
750772

torchvision/transforms/transforms.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,12 +550,15 @@ class RandomPerspective(object):
550550
551551
distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5.
552552
553+
fill (3-tuple or int): RGB pixel fill value for area outside the transformed image.
554+
If int, the value is used for all channels respectively. Default value is 0.
553555
"""
554556

555-
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0):
    # Just record the configuration; the warp itself happens in __call__.
    self.distortion_scale = distortion_scale
    self.p = p
    self.interpolation = interpolation
    self.fill = fill
559562

560563
def __call__(self, img):
561564
"""
@@ -571,7 +574,7 @@ def __call__(self, img):
571574
if random.random() < self.p:
572575
width, height = img.size
573576
startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
574-
return F.perspective(img, startpoints, endpoints, self.interpolation)
577+
return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
575578
return img
576579

577580
@staticmethod

0 commit comments

Comments
 (0)