diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 4bb18cf6b48..d4523891d99 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -316,6 +316,7 @@ Others
     v2.RandomHorizontalFlip
     v2.RandomVerticalFlip
     v2.Pad
+    v2.PadToSquare
     v2.RandomZoomOut
     v2.RandomRotation
     v2.RandomAffine
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index e16c0677c9f..ffd203f846e 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -3928,6 +3928,44 @@ def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
         assert_equal(actual, expected)
 
 
+class TestPadToSquare:
+    @pytest.mark.parametrize(
+        "image",
+        [
+            (make_image((3, 10), device="cpu", dtype=torch.uint8)),
+            (make_image((10, 3), device="cpu", dtype=torch.uint8)),
+            (make_image((10, 10), device="cpu", dtype=torch.uint8)),
+        ],
+    )
+    def test__get_params(self, image):
+        transform = transforms.PadToSquare()
+        params = transform._get_params([image])
+
+        assert "padding" in params
+        padding = params["padding"]
+
+        assert len(padding) == 4
+        assert all(p >= 0 for p in padding)
+
+        height, width = F.get_size(image)
+        assert max(height, width) == height + padding[1] + padding[3]
+        assert max(height, width) == width + padding[0] + padding[2]
+
+    @pytest.mark.parametrize(
+        "image, expected_output_shape",
+        [
+            (make_image((3, 10), device="cpu", dtype=torch.uint8), [10, 10]),
+            (make_image((10, 3), device="cpu", dtype=torch.uint8), [10, 10]),
+            (make_image((10, 10), device="cpu", dtype=torch.uint8), [10, 10]),
+        ],
+    )
+    def test_pad_square_correctness(self, image, expected_output_shape):
+        transform = transforms.PadToSquare()
+        output = transform(image)
+
+        assert F.get_size(output) == expected_output_shape
+
+
 class TestCenterCrop:
     INPUT_SIZE = (17, 11)
     OUTPUT_SIZES = [(3, 5), (5, 3), (4, 4), (21, 9), (13, 15), (19, 14), 3, (4,), [5], INPUT_SIZE]
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index 2d66917b6ea..61470d2ac85 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -26,6 +26,7 @@
     ElasticTransform,
     FiveCrop,
     Pad,
+    PadToSquare,
     RandomAffine,
     RandomCrop,
     RandomHorizontalFlip,
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 5d6b1841d7f..744db5a22ef 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -488,6 +488,84 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
 
 
+class PadToSquare(Transform):
+    """Pad a non-square input to make it square by padding the shorter side to match the longer side.
+
+    Args:
+        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
+            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
+            Fill value can be also a dictionary mapping data type to the fill value, e.g.
+            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
+            ``Mask`` will be filled with 0.
+        padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
+            Default is "constant".
+
+            - constant: pads with a constant value, this value is specified with fill
+
+            - edge: pads with the last value at the edge of the image.
+
+            - reflect: pads with reflection of image without repeating the last value on the edge.
+              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+              will result in [3, 2, 1, 2, 3, 4, 3, 2]
+
+            - symmetric: pads with reflection of image repeating the last value on the edge.
+              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+              will result in [2, 1, 1, 2, 3, 4, 4, 3]
+
+    Example:
+        >>> import torch
+        >>> from torchvision.transforms.v2 import PadToSquare
+        >>> rectangular_image = torch.randint(0, 255, (3, 224, 168), dtype=torch.uint8)
+        >>> transform = PadToSquare(padding_mode='constant', fill=0)
+        >>> square_image = transform(rectangular_image)
+        >>> print(square_image.size())
+        torch.Size([3, 224, 224])
+    """
+
+    def __init__(
+        self,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
+        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
+    ):
+        super().__init__()
+
+        _check_padding_mode_arg(padding_mode)
+
+        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+            raise ValueError("`padding_mode` must be one of 'constant', 'edge', 'reflect' or 'symmetric'.")
+        self.padding_mode = padding_mode
+        self.fill = _setup_fill_arg(fill)
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        # Get the original height and width from the inputs
+        orig_height, orig_width = query_size(flat_inputs)
+
+        # Find the target size (maximum of height and width)
+        target_size = max(orig_height, orig_width)
+
+        if orig_height < target_size:
+            # Need to pad height
+            pad_height = target_size - orig_height
+            pad_top = pad_height // 2
+            pad_bottom = pad_height - pad_top
+            pad_left = 0
+            pad_right = 0
+        else:
+            # Need to pad width
+            pad_width = target_size - orig_width
+            pad_left = pad_width // 2
+            pad_right = pad_width - pad_left
+            pad_top = 0
+            pad_bottom = 0
+
+        # The padding needs to be in the format [left, top, right, bottom]
+        return dict(padding=[pad_left, pad_top, pad_right, pad_bottom])
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        fill = _get_fill(self.fill, type(inpt))
+        return self._call_kernel(F.pad, inpt, padding=params["padding"], padding_mode=self.padding_mode, fill=fill)
+
+
 class RandomZoomOut(_RandomApplyTransform):
     """ "Zoom out" transformation from
     `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.
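
Note for reviewers: `PadToSquare` only exists once this patch is applied, but the padding rule it implements can be reproduced today with the existing `torchvision.transforms.v2.functional` API (`get_size` and `pad`). The snippet below is a minimal stand-alone sketch of that rule; `pad_to_square` is a hypothetical helper name used for illustration and is not part of this patch or of torchvision.

# Stand-alone sketch of the PadToSquare padding rule, built only on the
# existing torchvision v2 functional API. `pad_to_square` is a hypothetical
# helper that mirrors the logic of PadToSquare._get_params above.
import torch
from torchvision.transforms.v2 import functional as F


def pad_to_square(image: torch.Tensor, fill: int = 0, padding_mode: str = "constant") -> torch.Tensor:
    height, width = F.get_size(image)
    target = max(height, width)

    # Split the difference so the content stays centered; for odd differences
    # the extra pixel goes to the bottom/right, as in _get_params.
    pad_height, pad_width = target - height, target - width
    pad_top, pad_bottom = pad_height // 2, pad_height - pad_height // 2
    pad_left, pad_right = pad_width // 2, pad_width - pad_width // 2

    # F.pad expects the padding as [left, top, right, bottom].
    return F.pad(image, padding=[pad_left, pad_top, pad_right, pad_bottom], fill=fill, padding_mode=padding_mode)


img = torch.randint(0, 255, (3, 224, 168), dtype=torch.uint8)
print(pad_to_square(img).shape)  # torch.Size([3, 224, 224])

Centering the pad, rather than padding only the right/bottom edge, keeps the original content in the middle of the square output, which matches how `_get_params` splits the padding in the patch.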