Skip to content

Commit ba64d65

Browse files
Fast rotation for right angles (#8295)
Co-authored-by: Thien Tran <[email protected]>
1 parent c7bcfad commit ba64d65

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

test/test_transforms_v2.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,6 +1782,17 @@ def test_transform_unknown_fill_error(self):
17821782
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
17831783
transforms.RandomAffine(degrees=0, fill="fill")
17841784

1785+
@pytest.mark.parametrize("size", [(11, 17), (16, 16)])
1786+
@pytest.mark.parametrize("angle", [0, 90, 180, 270])
1787+
@pytest.mark.parametrize("expand", [False, True])
1788+
def test_functional_image_fast_path_correctness(self, size, angle, expand):
1789+
image = make_image(size, dtype=torch.uint8, device="cpu")
1790+
1791+
actual = F.rotate(image, angle=angle, expand=expand)
1792+
expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle, expand=expand))
1793+
1794+
torch.testing.assert_close(actual, expected)
1795+
17851796

17861797
class TestContainerTransforms:
17871798
class BuiltinTransform(transforms.Transform):

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,21 @@ def rotate_image(
997997
center: Optional[List[float]] = None,
998998
fill: _FillTypeJIT = None,
999999
) -> torch.Tensor:
1000+
angle = angle % 360 # shift angle to [0, 360) range
1001+
1002+
# fast path: transpose without affine transform
1003+
if center is None:
1004+
if angle == 0:
1005+
return image.clone()
1006+
if angle == 180:
1007+
return torch.rot90(image, k=2, dims=(-2, -1))
1008+
1009+
if expand or image.shape[-1] == image.shape[-2]:
1010+
if angle == 90:
1011+
return torch.rot90(image, k=1, dims=(-2, -1))
1012+
if angle == 270:
1013+
return torch.rot90(image, k=3, dims=(-2, -1))
1014+
10001015
interpolation = _check_interpolation(interpolation)
10011016

10021017
input_height, input_width = image.shape[-2:]

0 commit comments

Comments
 (0)