Skip to content

Commit b904a93

Browse files
committed
Implemented PadToSquare. Added Docs & Tests.
1 parent e9a3213 commit b904a93

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

docs/source/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ Others
316316
v2.RandomHorizontalFlip
317317
v2.RandomVerticalFlip
318318
v2.Pad
319+
v2.PadToSquare
319320
v2.RandomZoomOut
320321
v2.RandomRotation
321322
v2.RandomAffine

test/test_transforms_v2.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3928,6 +3928,44 @@ def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
39283928
assert_equal(actual, expected)
39293929

39303930

3931+
class TestPadToSquare:
3932+
@pytest.mark.parametrize(
3933+
"image",
3934+
[
3935+
(make_image((3, 10), device="cpu", dtype=torch.uint8)),
3936+
(make_image((10, 3), device="cpu", dtype=torch.uint8)),
3937+
(make_image((10, 10), device="cpu", dtype=torch.uint8)),
3938+
],
3939+
)
3940+
def test__get_params(self, image):
3941+
transform = transforms.PadToSquare()
3942+
params = transform._get_params([image])
3943+
3944+
assert "padding" in params
3945+
padding = params["padding"]
3946+
3947+
assert len(padding) == 4
3948+
assert all(p >= 0 for p in padding)
3949+
3950+
height, width = F.get_size(image)
3951+
assert max(height, width) == height + padding[1] + padding[3]
3952+
assert max(height, width) == width + padding[0] + padding[2]
3953+
3954+
@pytest.mark.parametrize(
3955+
"image, expected_output_shape",
3956+
[
3957+
(make_image((3, 10), device="cpu", dtype=torch.uint8), [10, 10]),
3958+
(make_image((10, 3), device="cpu", dtype=torch.uint8), [10, 10]),
3959+
(make_image((10, 10), device="cpu", dtype=torch.uint8), [10, 10]),
3960+
],
3961+
)
3962+
def test_pad_square_correctness(self, image, expected_output_shape):
3963+
transform = transforms.PadToSquare()
3964+
output = transform(image)
3965+
3966+
assert F.get_size(output) == expected_output_shape
3967+
3968+
39313969
class TestCenterCrop:
39323970
INPUT_SIZE = (17, 11)
39333971
OUTPUT_SIZES = [(3, 5), (5, 3), (4, 4), (21, 9), (13, 15), (19, 14), 3, (4,), [5], INPUT_SIZE]

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ElasticTransform,
2727
FiveCrop,
2828
Pad,
29+
PadToSquare,
2930
RandomAffine,
3031
RandomCrop,
3132
RandomHorizontalFlip,

torchvision/transforms/v2/_geometry.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,84 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
488488
return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
489489

490490

491+
class PadToSquare(Transform):
492+
"""Pad a non-square input to make it square by padding the shorter side to match the longer side.
493+
494+
Args:
495+
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
496+
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
497+
Fill value can be also a dictionary mapping data type to the fill value, e.g.
498+
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
499+
``Mask`` will be filled with 0.
500+
padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
501+
Default is "constant".
502+
503+
- constant: pads with a constant value, this value is specified with fill
504+
505+
- edge: pads with the last value at the edge of the image.
506+
507+
- reflect: pads with reflection of image without repeating the last value on the edge.
508+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
509+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
510+
511+
- symmetric: pads with reflection of image repeating the last value on the edge.
512+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
513+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
514+
515+
Example:
516+
>>> import torch
517+
>>> from torchvision.transforms.v2 import PadToSquare
518+
>>> rectangular_image = torch.randint(0, 255, (3, 224, 168), dtype=torch.uint8)
519+
>>> transform = PadToSquare(padding_mode='constant', fill=0)
520+
>>> square_image = transform(rectangular_image)
521+
>>> print(square_image.size())
522+
torch.Size([3, 224, 224])
523+
"""
524+
525+
def __init__(
526+
self,
527+
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
528+
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
529+
):
530+
super().__init__()
531+
532+
_check_padding_mode_arg(padding_mode)
533+
534+
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
535+
raise ValueError("`padding_mode` must be one of 'constant', 'edge', 'reflect' or 'symmetric'.")
536+
self.padding_mode = padding_mode
537+
self.fill = _setup_fill_arg(fill)
538+
539+
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
540+
# Get the original height and width from the inputs
541+
orig_height, orig_width = query_size(flat_inputs)
542+
543+
# Find the target size (maximum of height and width)
544+
target_size = max(orig_height, orig_width)
545+
546+
if orig_height < target_size:
547+
# Need to pad height
548+
pad_height = target_size - orig_height
549+
pad_top = pad_height // 2
550+
pad_bottom = pad_height - pad_top
551+
pad_left = 0
552+
pad_right = 0
553+
else:
554+
# Need to pad width
555+
pad_width = target_size - orig_width
556+
pad_left = pad_width // 2
557+
pad_right = pad_width - pad_left
558+
pad_top = 0
559+
pad_bottom = 0
560+
561+
# The padding needs to be in the format [left, top, right, bottom]
562+
return dict(padding=[pad_left, pad_top, pad_right, pad_bottom])
563+
564+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
565+
fill = _get_fill(self.fill, type(inpt))
566+
return self._call_kernel(F.pad, inpt, padding=params["padding"], padding_mode=self.padding_mode, fill=fill)
567+
568+
491569
class RandomZoomOut(_RandomApplyTransform):
492570
""" "Zoom out" transformation from
493571
`"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.

0 commit comments

Comments
 (0)