Skip to content

Commit f124509

Browse files
utkuozbulakfmassa
authored andcommitted
Add reflect, symmetric and edge padding (#460)
* Added reflect, symmetric and edge padding * Updated padding docs, added tests
1 parent ca5d4db commit f124509

File tree

3 files changed

+89
-8
lines changed

3 files changed

+89
-8
lines changed

test/test_transforms.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,37 @@ def test_pad_with_tuple_of_pad_values(self):
235235
# Checking if Padding can be printed as string
236236
transforms.Pad(padding).__repr__()
237237

238+
def test_pad_with_non_constant_padding_modes(self):
239+
"""Unit tests for edge, reflect, symmetric padding"""
240+
img = torch.zeros(3, 27, 27)
241+
img[:, :, 0] = 1 # Constant value added to leftmost edge
242+
img = transforms.ToPILImage()(img)
243+
img = F.pad(img, 1, (200, 200, 200))
244+
245+
# pad 3 to all sidess
246+
edge_padded_img = F.pad(img, 3, padding_mode='edge')
247+
# First 6 elements of leftmost edge in the middle of the image, values are in order:
248+
# edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
249+
edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
250+
assert np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 255, 0]))
251+
assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35)
252+
253+
# Pad 3 to left/right, 2 to top/bottom
254+
reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect')
255+
# First 6 elements of leftmost edge in the middle of the image, values are in order:
256+
# reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
257+
reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
258+
assert np.all(reflect_middle_slice == np.asarray([0, 0, 255, 200, 255, 0]))
259+
assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35)
260+
261+
# Pad 3 to left, 2 to top, 2 to right, 1 to bottom
262+
symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric')
263+
# First 6 elements of leftmost edge in the middle of the image, values are in order:
264+
# sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
265+
symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
266+
assert np.all(symmetric_middle_slice == np.asarray([0, 255, 200, 200, 255, 0]))
267+
assert transforms.ToTensor()(symmetric_padded_img).size() == (3, 32, 34)
268+
238269
def test_pad_raises_with_invalid_pad_sequence_len(self):
239270
with self.assertRaises(ValueError):
240271
transforms.Pad(())

torchvision/transforms/functional.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def scale(*args, **kwargs):
210210
return resize(*args, **kwargs)
211211

212212

213-
def pad(img, padding, fill=0):
214-
"""Pad the given PIL Image on all sides with the given "pad" value.
213+
def pad(img, padding, fill=0, padding_mode='constant'):
214+
"""Pad the given PIL Image on all sides with speficified padding mode and fill value.
215215
216216
Args:
217217
img (PIL Image): Image to be padded.
@@ -220,8 +220,18 @@ def pad(img, padding, fill=0):
220220
on left/right and top/bottom respectively. If a tuple of length 4 is provided
221221
this is the padding for the left, top, right and bottom borders
222222
respectively.
223-
fill: Pixel fill value. Default is 0. If a tuple of
223+
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
224224
length 3, it is used to fill R, G, B channels respectively.
225+
This value is only used when the padding_mode is constant
226+
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
227+
constant: pads with a constant value, this value is specified with fill
228+
edge: pads with the last value on the edge of the image
229+
reflect: pads with reflection of image (without repeating the last value on the edge)
230+
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
231+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
232+
symmetric: pads with reflection of image (repeating the last value on the edge)
233+
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
234+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
225235
226236
Returns:
227237
PIL Image: Padded image.
@@ -233,12 +243,39 @@ def pad(img, padding, fill=0):
233243
raise TypeError('Got inappropriate padding arg')
234244
if not isinstance(fill, (numbers.Number, str, tuple)):
235245
raise TypeError('Got inappropriate fill arg')
246+
if not isinstance(padding_mode, str):
247+
raise TypeError('Got inappropriate padding_mode arg')
236248

237249
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
238250
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
239251
"{} element tuple".format(len(padding)))
240252

241-
return ImageOps.expand(img, border=padding, fill=fill)
253+
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
254+
'Padding mode should be either constant, edge, reflect or symmetric'
255+
256+
if padding_mode == 'constant':
257+
return ImageOps.expand(img, border=padding, fill=fill)
258+
else:
259+
if isinstance(padding, int):
260+
pad_left = pad_right = pad_top = pad_bottom = padding
261+
if isinstance(padding, collections.Sequence) and len(padding) == 2:
262+
pad_left = pad_right = padding[0]
263+
pad_top = pad_bottom = padding[1]
264+
if isinstance(padding, collections.Sequence) and len(padding) == 4:
265+
pad_left = padding[0]
266+
pad_top = padding[1]
267+
pad_right = padding[2]
268+
pad_bottom = padding[3]
269+
270+
img = np.asarray(img)
271+
# RGB image
272+
if len(img.shape) == 3:
273+
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
274+
# Grayscale image
275+
if len(img.shape) == 2:
276+
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
277+
278+
return Image.fromarray(img)
242279

243280

244281
def crop(img, i, j, h, w):

torchvision/transforms/transforms.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,19 +227,31 @@ class Pad(object):
227227
on left/right and top/bottom respectively. If a tuple of length 4 is provided
228228
this is the padding for the left, top, right and bottom borders
229229
respectively.
230-
fill: Pixel fill value. Default is 0. If a tuple of
230+
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
231231
length 3, it is used to fill R, G, B channels respectively.
232+
This value is only used when the padding_mode is constant
233+
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
234+
constant: pads with a constant value, this value is specified with fill
235+
edge: pads with the last value at the edge of the image
236+
reflect: pads with reflection of image (without repeating the last value on the edge)
237+
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
238+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
239+
symmetric: pads with reflection of image (repeating the last value on the edge)
240+
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
241+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
232242
"""
233243

234-
def __init__(self, padding, fill=0):
244+
def __init__(self, padding, fill=0, padding_mode='constant'):
235245
assert isinstance(padding, (numbers.Number, tuple))
236246
assert isinstance(fill, (numbers.Number, str, tuple))
247+
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
237248
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
238249
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
239250
"{} element tuple".format(len(padding)))
240251

241252
self.padding = padding
242253
self.fill = fill
254+
self.padding_mode = padding_mode
243255

244256
def __call__(self, img):
245257
"""
@@ -249,10 +261,11 @@ def __call__(self, img):
249261
Returns:
250262
PIL Image: Padded image.
251263
"""
252-
return F.pad(img, self.padding, self.fill)
264+
return F.pad(img, self.padding, self.fill, self.padding_mode)
253265

254266
def __repr__(self):
255-
return self.__class__.__name__ + '(padding={0}, fill={1})'.format(self.padding, self.fill)
267+
return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
268+
format(self.padding, self.fill, self.padding_mode)
256269

257270

258271
class Lambda(object):

0 commit comments

Comments
 (0)