Skip to content

Commit c96676a

Browse files
Fix tests for int rotated boxes
1 parent 0629b64 commit c96676a

File tree

3 files changed

+55
-29
lines changed

3 files changed

+55
-29
lines changed

test/common_utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,13 @@ def sample_position(values, max_value):
417417
format = tv_tensors.BoundingBoxFormat[format]
418418

419419
dtype = dtype or torch.float32
420+
int_dtype = dtype in (
421+
torch.uint8,
422+
torch.int8,
423+
torch.int16,
424+
torch.int32,
425+
torch.int64,
426+
)
420427

421428
h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
422429
y = sample_position(h, canvas_size[0])
@@ -443,17 +450,17 @@ def sample_position(values, max_value):
443450
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
444451
r_rad = r * torch.pi / 180.0
445452
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
446-
x1, y1 = x, y
447-
x2 = x1 + w * cos
448-
y2 = y1 - w * sin
449-
x3 = x2 + h * sin
450-
y3 = y2 + h * cos
451-
x4 = x1 + h * sin
452-
y4 = y1 + h * cos
453+
x1 = torch.round(x) if int_dtype else x
454+
y1 = torch.round(y) if int_dtype else y
455+
x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
456+
y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
457+
x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
458+
y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
459+
x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
460+
y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
453461
parts = (x1, y1, x2, y2, x3, y3, x4, y4)
454462
else:
455463
raise ValueError(f"Format {format} is not supported")
456-
457464
return tv_tensors.BoundingBoxes(
458465
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
459466
)

test/test_transforms_v2.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5939,6 +5939,15 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t
59395939
assert out_label == label
59405940

59415941

5942+
@pytest.mark.parametrize("input_size", [(17, 11), (11, 17), (11, 11)])
5943+
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
5944+
@pytest.mark.parametrize("device", cpu_and_cuda())
5945+
def test_parallelogram_to_bounding_boxes(input_size, dtype, device):
5946+
bounding_boxes = make_bounding_boxes(input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, dtype=dtype, device=device)
5947+
actual = _parallelogram_to_bounding_boxes(bounding_boxes)
5948+
torch.testing.assert_close(actual, bounding_boxes, rtol=0, atol=1)
5949+
5950+
59425951
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
59435952
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
59445953
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -381,29 +381,32 @@ def _resize_mask_dispatch(
381381
return tv_tensors.wrap(output, like=inpt)
382382

383383

384-
def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor, inplace: bool = False) -> torch.Tensor:
384+
def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tensor:
385385
"""
386386
Convert a parallelogram to a rectangle while keeping the points (x1, y1) and (x3, y3) unchanged.
387387
388388
This function transforms a parallelogram represented by 8 coordinates (4 points) into a rectangle.
389389
The first point (x1, y1) and the third point (x3, y3) of the parallelogram remain fixed,
390390
while the second and fourth points are adjusted to form a proper rectangle.
391391
392+
Note:
393+
This function is not applied in-place and will return a copy of the input tensor.
394+
392395
Args:
393396
parallelogram (torch.Tensor): Tensor of shape (..., 8) containing coordinates of parallelograms.
394397
Format is [x1, y1, x2, y2, x3, y3, x4, y4].
395-
inplace (bool, optional): If True, performs operation in-place. Default is False.
396398
397399
Returns:
398400
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
399401
The output maintains the same dtype as the input.
400402
"""
401-
if not inplace:
402-
parallelogram = parallelogram.clone()
403-
404403
dtype = parallelogram.dtype
405-
if not torch.is_floating_point(parallelogram):
406-
parallelogram = parallelogram.float()
404+
int_dtype = dtype in (torch.uint8,
405+
torch.int8,
406+
torch.int16,
407+
torch.int32,
408+
torch.int64,
409+
)
407410

408411
# Calculate diagonal vector from first to third point
409412
dx = parallelogram[..., 4] - parallelogram[..., 0]
@@ -417,21 +420,28 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor, inplace: bool
417420
# Calculate width using the angle between diagonal and rotation
418421
w = diag * torch.abs(torch.sin(torch.atan2(dx, dy) - r_rad))
419422

423+
delta_x = torch.round(w * cos).to(dtype) if int_dtype else w * cos
424+
detla_y = torch.round(w * sin).to(dtype) if int_dtype else w * sin
425+
420426
# Update coordinates to form a rectangle
421-
parallelogram[..., 2] = parallelogram[..., 0] + w * cos
422-
parallelogram[..., 3] = parallelogram[..., 1] - w * sin
423-
parallelogram[..., 6] = parallelogram[..., 4] - w * cos
424-
parallelogram[..., 7] = parallelogram[..., 5] + w * sin
425-
return parallelogram.to(dtype)
427+
parallelogram[..., 2] = parallelogram[..., 0] + delta_x
428+
parallelogram[..., 3] = parallelogram[..., 1] - detla_y
429+
parallelogram[..., 6] = parallelogram[..., 4] - delta_x
430+
parallelogram[..., 7] = parallelogram[..., 5] + detla_y
431+
return parallelogram
426432

427433

428434
def resize_bounding_boxes(
429435
bounding_boxes: torch.Tensor,
430-
format: tv_tensors.BoundingBoxFormat,
431436
canvas_size: tuple[int, int],
432437
size: Optional[list[int]],
433438
max_size: Optional[int] = None,
439+
format: tv_tensors.BoundingBoxFormat = tv_tensors.BoundingBoxFormat.XYXY,
434440
) -> tuple[torch.Tensor, tuple[int, int]]:
441+
# We set the default format as `tv_tensors.BoundingBoxFormat.XYXY`
442+
# to ensure backward compatibility.
443+
# Indeed before the introduction of rotated bounding box format
444+
# this function did not received `format` parameter as input.
435445
old_height, old_width = canvas_size
436446
new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)
437447

@@ -893,12 +903,9 @@ def _affine_bounding_boxes_with_expand(
893903
bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
894904
dtype = bounding_boxes.dtype
895905
device = bounding_boxes.device
896-
intermediate_format = (
897-
tv_tensors.BoundingBoxFormat.XYXYXYXY
898-
if tv_tensors.is_rotated_bounding_format(format)
899-
else tv_tensors.BoundingBoxFormat.XYXY
900-
)
901-
intermediate_shape = 8 if tv_tensors.is_rotated_bounding_format(format) else 4
906+
is_rotated = tv_tensors.is_rotated_bounding_format(format)
907+
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
908+
intermediate_shape = 8 if is_rotated else 4
902909
bounding_boxes = (
903910
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format, inplace=True)
904911
).reshape(-1, intermediate_shape)
@@ -925,7 +932,7 @@ def _affine_bounding_boxes_with_expand(
925932
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
926933
# Single point structure is similar to
927934
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
928-
if tv_tensors.is_rotated_bounding_format(format):
935+
if is_rotated:
929936
points = bounding_boxes.reshape(-1, 2)
930937
else:
931938
points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
@@ -934,7 +941,7 @@ def _affine_bounding_boxes_with_expand(
934941
transformed_points = torch.matmul(points, transposed_affine_matrix)
935942
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
936943
# and compute bounding box from 4 transformed points:
937-
if tv_tensors.is_rotated_bounding_format(format):
944+
if is_rotated:
938945
transformed_points = transformed_points.reshape(-1, 8)
939946
out_bboxes = _parallelogram_to_bounding_boxes(transformed_points)
940947
else:
@@ -1557,6 +1564,9 @@ def crop_bounding_boxes(
15571564
bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
15581565
canvas_size = (height, width)
15591566

1567+
if format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
1568+
bounding_boxes = _parallelogram_to_bounding_boxes(bounding_boxes)
1569+
15601570
return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
15611571

15621572

0 commit comments

Comments
 (0)