
Commit 782f406

Update resize_bounding_boxes for rotated boxes
Test Plan:

```bash
pytest test/test_transforms_v2.py -vvv -k "TestResize and test_kernel_bounding_boxes"
pytest test/test_transforms_v2.py -vvv -k "TestResize and test_bounding_boxes_correctness"
```
1 parent 87e821f commit 782f406
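
The commit threads a new `format` argument through the `F.resize_bounding_boxes` kernel so rotated boxes get corner-wise scaling. A minimal sketch of the updated call, assuming a torchvision build where rotated formats such as XYWHR are available; the box values, canvas size, and target size below are made up for illustration and are not taken from the commit:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# One rotated box in XYWHR layout: x, y, width, height, rotation angle.
boxes = torch.tensor([[10.0, 10.0, 40.0, 20.0, 30.0]])

# The kernel now receives the box format explicitly so it can branch on
# rotated vs. axis-aligned handling.
resized, new_canvas_size = F.resize_bounding_boxes(
    boxes,
    format=tv_tensors.BoundingBoxFormat.XYWHR,
    canvas_size=(100, 100),
    size=[50, 50],
)
print(resized.shape, new_canvas_size)  # torch.Size([1, 5]) (50, 50)
```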

File tree

- test/test_transforms_v2.py
- torchvision/transforms/v2/functional/_geometry.py

2 files changed: 100 additions, 20 deletions

test/test_transforms_v2.py

Lines changed: 25 additions & 13 deletions
```diff
@@ -49,7 +49,7 @@
 from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2._utils import check_type, is_pure_tensor
-from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs
+from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
 from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal


@@ -560,7 +560,9 @@ def affine_bounding_boxes(bounding_boxes):
     )


-def reference_affine_rotated_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
+def reference_affine_rotated_bounding_boxes_helper(
+    bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True, flip=False
+):
     format = bounding_boxes.format
     canvas_size = new_canvas_size or bounding_boxes.canvas_size
@@ -588,17 +590,20 @@ def affine_rotated_bounding_boxes(bounding_boxes):
         transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T)
         output = torch.tensor(
             [
-                float(transformed_points[1, 0]),
-                float(transformed_points[1, 1]),
                 float(transformed_points[0, 0]),
                 float(transformed_points[0, 1]),
-                float(transformed_points[3, 0]),
-                float(transformed_points[3, 1]),
+                float(transformed_points[1, 0]),
+                float(transformed_points[1, 1]),
                 float(transformed_points[2, 0]),
                 float(transformed_points[2, 1]),
+                float(transformed_points[3, 0]),
+                float(transformed_points[3, 1]),
             ]
         )

+        output = output[[2, 3, 0, 1, 6, 7, 4, 5]] if flip else output
+        output = _parallelogram_to_bounding_boxes(output)
+
         output = F.convert_bounding_box_format(
             output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
         )
@@ -707,7 +712,7 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype,
             check_scripted_vs_eager=not isinstance(size, int),
         )

-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("use_max_size", [True, False])
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -725,6 +730,7 @@ def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device):
         check_kernel(
             F.resize_bounding_boxes,
             bounding_boxes,
+            format=format,
             canvas_size=bounding_boxes.canvas_size,
             size=size,
             **max_size_kwarg,
@@ -816,7 +822,7 @@ def test_image_correctness(self, size, interpolation, use_max_size, fn):
         self._check_output_size(image, actual, size=size, **max_size_kwarg)
         torch.testing.assert_close(actual, expected, atol=1, rtol=0)

-    def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None):
+    def _reference_resize_bounding_boxes(self, bounding_boxes, format, *, size, max_size=None):
         old_height, old_width = bounding_boxes.canvas_size
         new_height, new_width = self._compute_output_size(
             input_size=bounding_boxes.canvas_size, size=size, max_size=max_size
@@ -832,13 +838,19 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non
             ],
         )

-        return reference_affine_bounding_boxes_helper(
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
+        )
+
+        return helper(
             bounding_boxes,
             affine_matrix=affine_matrix,
             new_canvas_size=(new_height, new_width),
         )

-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("use_max_size", [True, False])
     @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
@@ -849,7 +861,7 @@ def test_bounding_boxes_correctness(self, format, size, use_max_size, fn):
         bounding_boxes = make_bounding_boxes(format=format, canvas_size=self.INPUT_SIZE)

         actual = fn(bounding_boxes, size=size, **max_size_kwarg)
-        expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg)
+        expected = self._reference_resize_bounding_boxes(bounding_boxes, format=format, size=size, **max_size_kwarg)

         self._check_output_size(bounding_boxes, actual, size=size, **max_size_kwarg)
         torch.testing.assert_close(actual, expected)
@@ -1152,7 +1164,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B
         )

         helper = (
-            reference_affine_rotated_bounding_boxes_helper
+            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
@@ -1607,7 +1619,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
         )

         helper = (
-            reference_affine_rotated_bounding_boxes_helper
+            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
```
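
In the flip reference helpers, the new `flip=True` path reorders the XYXYXYXY corners before the parallelogram correction is applied. A tiny sketch of that index permutation; the corner values are arbitrary examples, and only the `[2, 3, 0, 1, 6, 7, 4, 5]` ordering itself comes from the diff:

```python
import torch

# Corners of a box in XYXYXYXY order: (x1, y1), (x2, y2), (x3, y3), (x4, y4).
corners = torch.tensor([0.0, 0.0, 4.0, 0.0, 4.0, 2.0, 0.0, 2.0])

# flip=True swaps point 1 with point 2 and point 3 with point 4, presumably so
# the reference corner order matches what the flip kernels emit.
reordered = corners[[2, 3, 0, 1, 6, 7, 4, 5]]
print(reordered)  # tensor([4., 0., 0., 0., 0., 2., 4., 2.])
```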

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 75 additions & 7 deletions
```diff
@@ -381,8 +381,53 @@ def _resize_mask_dispatch(
     return tv_tensors.wrap(output, like=inpt)


+def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+    """
+    Convert a parallelogram to a rectangle while keeping the points (x1, y1) and (x3, y3) unchanged.
+
+    This function transforms a parallelogram represented by 8 coordinates (4 points) into a rectangle.
+    The first point (x1, y1) and the third point (x3, y3) of the parallelogram remain fixed,
+    while the second and fourth points are adjusted to form a proper rectangle.
+
+    Args:
+        parallelogram (torch.Tensor): Tensor of shape (..., 8) containing coordinates of parallelograms.
+            Format is [x1, y1, x2, y2, x3, y3, x4, y4].
+        inplace (bool, optional): If True, performs operation in-place. Default is False.
+
+    Returns:
+        torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
+            The output maintains the same dtype as the input.
+    """
+    if not inplace:
+        parallelogram = parallelogram.clone()
+
+    dtype = parallelogram.dtype
+    if not torch.is_floating_point(parallelogram):
+        parallelogram = parallelogram.float()
+
+    # Calculate diagonal vector from first to third point
+    dx = parallelogram[..., 4] - parallelogram[..., 0]
+    dy = parallelogram[..., 5] - parallelogram[..., 1]
+    diag = torch.sqrt(dx**2 + dy**2)
+
+    # Calculate rotation angle in radians
+    r_rad = torch.atan2(parallelogram[..., 1] - parallelogram[..., 3], parallelogram[..., 2] - parallelogram[..., 0])
+    cos, sin = torch.cos(r_rad), torch.sin(r_rad)
+
+    # Calculate width using the angle between diagonal and rotation
+    w = diag * torch.abs(torch.sin(torch.atan2(dx, dy) - r_rad))
+
+    # Update coordinates to form a rectangle
+    parallelogram[..., 2] = parallelogram[..., 0] + w * cos
+    parallelogram[..., 3] = parallelogram[..., 1] - w * sin
+    parallelogram[..., 6] = parallelogram[..., 4] - w * cos
+    parallelogram[..., 7] = parallelogram[..., 5] + w * sin
+    return parallelogram.to(dtype)
+
+
 def resize_bounding_boxes(
     bounding_boxes: torch.Tensor,
+    format: tv_tensors.BoundingBoxFormat,
     canvas_size: tuple[int, int],
     size: Optional[list[int]],
     max_size: Optional[int] = None,
@@ -395,19 +440,42 @@ def resize_bounding_boxes(

     w_ratio = new_width / old_width
     h_ratio = new_height / old_height
-    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
-    return (
-        bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
-        (new_height, new_width),
-    )
+    if tv_tensors.is_rotated_bounding_format(format):
+        original_shape = bounding_boxes.shape
+        xyxyxyxy_boxes = convert_bounding_box_format(
+            bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, inplace=False
+        ).reshape(-1, 8)
+
+        ratios = torch.tensor(
+            [w_ratio, h_ratio, w_ratio, h_ratio, w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device
+        )
+        transformed_points = xyxyxyxy_boxes.mul(ratios)
+        out_bboxes = _parallelogram_to_bounding_boxes(transformed_points)
+        return (
+            convert_bounding_box_format(
+                out_bboxes,
+                old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY,
+                new_format=format,
+                inplace=False,
+            )
+            .to(bounding_boxes.dtype)
+            .reshape(original_shape),
+            (new_height, new_width),
+        )
+    else:
+        ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
+        return (
+            bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
+            (new_height, new_width),
+        )


 @_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _resize_bounding_boxes_dispatch(
     inpt: tv_tensors.BoundingBoxes, size: Optional[list[int]], max_size: Optional[int] = None, **kwargs: Any
 ) -> tv_tensors.BoundingBoxes:
     output, canvas_size = resize_bounding_boxes(
-        inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
+        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, size=size, max_size=max_size
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -2204,7 +2272,7 @@ def resized_crop_bounding_boxes(
     size: list[int],
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
-    return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)
+    return resize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size, size=size)


 @_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
```
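
One easy property to check against the new helper: per its docstring, only the second and fourth corners are recomputed, so an input that is already an axis-aligned rectangle should come back unchanged. A quick sketch of that check; it is not part of the commit, `_parallelogram_to_bounding_boxes` is private so this import path may change, and the rectangle values are arbitrary:

```python
import torch
from torchvision.transforms.v2.functional._geometry import _parallelogram_to_bounding_boxes

# A 4x2 axis-aligned rectangle in XYXYXYXY order is already a rectangle, so the
# fixed first/third corners and the recomputed second/fourth corners coincide
# with the input (up to floating-point noise).
rect = torch.tensor([0.0, 0.0, 4.0, 0.0, 4.0, 2.0, 0.0, 2.0])
torch.testing.assert_close(_parallelogram_to_bounding_boxes(rect), rect)
```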

0 commit comments
