Skip to content
Merged
23 changes: 8 additions & 15 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,6 @@ def sample_position(values, max_value):
format = tv_tensors.BoundingBoxFormat[format]

dtype = dtype or torch.float32
int_dtype = dtype in (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
)

h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
y = sample_position(h, canvas_size[0])
Expand All @@ -457,14 +450,14 @@ def sample_position(values, max_value):
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
r_rad = r * torch.pi / 180.0
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
x1 = torch.round(x) if int_dtype else x
y1 = torch.round(y) if int_dtype else y
x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
x1 = x
y1 = y
x2 = x1 + w * cos
y2 = y1 - w * sin
x3 = x2 + h * sin
y3 = y2 + h * cos
x4 = x1 + h * sin
y4 = y1 + h * cos
parts = (x1, y1, x2, y2, x3, y3, x4, y4)
else:
raise ValueError(f"Format {format} is not supported")
Expand Down
51 changes: 33 additions & 18 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,13 +564,6 @@ def reference_affine_rotated_bounding_boxes_helper(

def affine_rotated_bounding_boxes(bounding_boxes):
dtype = bounding_boxes.dtype
int_dtype = dtype in (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
)
device = bounding_boxes.device

# Go to float before converting to prevent precision loss in case of CXCYWHR -> XYXYXYXY and W or H is 1
Expand Down Expand Up @@ -605,17 +598,12 @@ def affine_rotated_bounding_boxes(bounding_boxes):
)

output = output[[2, 3, 0, 1, 6, 7, 4, 5]] if flip else output
if not int_dtype:
output = _parallelogram_to_bounding_boxes(output)
output = _parallelogram_to_bounding_boxes(output)

output = F.convert_bounding_box_format(
output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
)

if torch.is_floating_point(output) and int_dtype:
# It is important to round before cast.
output = torch.round(output)

# For rotated boxes, it is important to cast before clamping.
return (
F.clamp_bounding_boxes(
Expand Down Expand Up @@ -760,6 +748,8 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype,
def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device):
if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
return
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")

bounding_boxes = make_bounding_boxes(
format=format,
Expand Down Expand Up @@ -1212,6 +1202,8 @@ def test_kernel_image(self, dtype, device):
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
check_kernel(
F.horizontal_flip_bounding_boxes,
Expand Down Expand Up @@ -1441,6 +1433,8 @@ def test_kernel_image(self, param, value, dtype, device):
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
self._check_kernel(
F.affine_bounding_boxes,
Expand Down Expand Up @@ -1823,6 +1817,8 @@ def test_kernel_image(self, dtype, device):
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
check_kernel(
F.vertical_flip_bounding_boxes,
Expand Down Expand Up @@ -2021,6 +2017,8 @@ def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
kwargs = {param: value}
if param != "angle":
kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")

bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)

Expand Down Expand Up @@ -3236,6 +3234,8 @@ def test_kernel_image(self, param, value, dtype, device):
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)

check_kernel(
Expand Down Expand Up @@ -3399,6 +3399,8 @@ def test_kernel_image(self, kwargs, dtype, device):
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, kwargs, format, dtype, device):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)
check_kernel(F.crop_bounding_boxes, bounding_boxes, format=format, **kwargs)

Expand Down Expand Up @@ -3576,6 +3578,8 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)

actual = F.crop(bounding_boxes, **kwargs)
Expand All @@ -3590,6 +3594,8 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, device, seed):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
input_size = [s * 2 for s in output_size]
bounding_boxes = make_bounding_boxes(input_size, format=format, dtype=dtype, device=device)

Expand Down Expand Up @@ -4267,6 +4273,10 @@ def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn_type", ["functional", "transform"])
def test_correctness(self, old_format, new_format, dtype, device, fn_type):
if not dtype.is_floating_point and (
tv_tensors.is_rotated_bounding_format(old_format) or tv_tensors.is_rotated_bounding_format(new_format)
):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device)

if fn_type == "functional":
Expand Down Expand Up @@ -4706,6 +4716,8 @@ def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)

actual = fn(bounding_boxes, padding=padding)
Expand Down Expand Up @@ -4876,6 +4888,8 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
def test_bounding_boxes_correctness(self, output_size, format, dtype, device, fn):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)

actual = fn(bounding_boxes, output_size)
Expand Down Expand Up @@ -5242,6 +5256,8 @@ def perspective_bounding_boxes(bounding_boxes):
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)

actual = F.perspective(bounding_boxes, startpoints=startpoints, endpoints=endpoints)
Expand Down Expand Up @@ -5511,6 +5527,8 @@ class TestClampBoundingBoxes:
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel(self, format, clamping_mode, dtype, device):
if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
pytest.xfail("Rotated bounding boxes should be floating point tensors")
bounding_boxes = make_bounding_boxes(format=format, clamping_mode=clamping_mode, dtype=dtype, device=device)
check_kernel(
F.clamp_bounding_boxes,
Expand Down Expand Up @@ -6938,14 +6956,11 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t


@pytest.mark.parametrize("input_size", [(17, 11), (11, 17), (11, 11)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_parallelogram_to_bounding_boxes(input_size, dtype, device):
def test_parallelogram_to_bounding_boxes(input_size, device):
# Assert that applying `_parallelogram_to_bounding_boxes` to rotated boxes
# does not modify the input.
bounding_boxes = make_bounding_boxes(
input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, dtype=dtype, device=device
)
bounding_boxes = make_bounding_boxes(input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, device=device)
actual = _parallelogram_to_bounding_boxes(bounding_boxes)
torch.testing.assert_close(actual, bounding_boxes, rtol=0, atol=1)

Expand Down
19 changes: 2 additions & 17 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,19 +462,6 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
The output maintains the same dtype as the input.
"""
dtype = parallelogram.dtype
int_dtype = dtype in (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
)
if int_dtype:
# Does not apply the transformation to `int` boxes as the rounding error
# will typically not ensure the resulting box has a rectangular shape.
return parallelogram.clone()

out_boxes = parallelogram.clone()

# Calculate parallelogram diagonal vectors
Expand All @@ -499,8 +486,8 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad)),
)

delta_x = torch.round(w * cos).to(dtype) if int_dtype else w * cos
delta_y = torch.round(w * sin).to(dtype) if int_dtype else w * sin
delta_x = w * cos
delta_y = w * sin
# Update coordinates to form a rectangle
# Keeping the points (x1, y1) and (x3, y3) unchanged.
out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2])
Expand Down Expand Up @@ -1196,8 +1183,6 @@ def _affine_bounding_boxes_with_expand(
).reshape(original_shape)

if need_cast:
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
out_bboxes.round_()
out_bboxes = out_bboxes.to(dtype)
return out_bboxes, canvas_size

Expand Down
18 changes: 5 additions & 13 deletions torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ def _clamp_along_y_axis(
dtype = bounding_boxes.dtype
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
need_cast = dtype not in acceptable_dtypes
eps = 1e-06 # Ensure consistency between CPU and GPU.
original_shape = bounding_boxes.shape
bounding_boxes = bounding_boxes.reshape(-1, 8)
original_bounding_boxes = original_bounding_boxes.reshape(-1, 8)
Expand All @@ -559,27 +558,23 @@ def _clamp_along_y_axis(
case_b[..., 6].clamp_(0) # Clamp x4 to 0
case_c = torch.zeros_like(case_b)

cond_a = (x1 < eps) & ~case_a.isnan().any(-1) # First point is outside left boundary
cond_b = y1.isclose(y2, rtol=eps, atol=eps) | y3.isclose(y4, rtol=eps, atol=eps) # First line is nearly vertical
cond_a = (x1 < 0) & ~case_a.isnan().any(-1) # First point is outside left boundary
cond_b = y1.isclose(y2) | y3.isclose(y4) # First line is nearly vertical
cond_c = (x1 <= 0) & (x2 <= 0) & (x3 <= 0) & (x4 <= 0) # All points outside left boundary
cond_c = (
cond_c
| y1.isclose(y4, rtol=eps, atol=eps)
| y2.isclose(y3, rtol=eps, atol=eps)
| (cond_b & x1.isclose(x2, rtol=eps, atol=eps))
| y1.isclose(y4)
| y2.isclose(y3)
| (cond_b & x1.isclose(x2))
) # First line is nearly horizontal

for (cond, case) in zip(
[cond_a, cond_b, cond_c],
[case_a, case_b, case_c],
):
bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
if clamping_mode == "hard":
bounding_boxes[..., 0].clamp_(0) # Clamp x1 to 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not immediately obvious that this relates to dtypes, so just flagging to make sure this change is intended?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug good catch. Yeah this was related to dtype and introduction of epsilon. I have remove it in a later commit attached to this PR.


if need_cast:
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
bounding_boxes.round_()
bounding_boxes = bounding_boxes.to(dtype)
return bounding_boxes.reshape(original_shape)

Expand Down Expand Up @@ -646,9 +641,6 @@ def _clamp_rotated_bounding_boxes(
).reshape(original_shape)

if need_cast:
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
# Adding epsilon to ensure consistency between CPU and GPU rounding.
out_boxes.add_(1e-7).round_()
out_boxes = out_boxes.to(dtype)
return out_boxes

Expand Down
6 changes: 6 additions & 0 deletions torchvision/tv_tensors/_bounding_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_
bounding_boxes.clamping_mode = clamping_mode
return bounding_boxes

@staticmethod
def _check_format(tensor: torch.Tensor, format: BoundingBoxFormat) -> None:
if not torch.is_floating_point(tensor) and is_rotated_bounding_format(format):
raise ValueError("Rotated bounding boxes should be floating point tensors")

def __new__(
cls,
data: Any,
Expand All @@ -111,6 +116,7 @@ def __new__(
requires_grad: bool | None = None,
) -> BoundingBoxes:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
cls._check_format(tensor, format=format)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we call it only once and it's only 2 lines, we can inline it instead of making it a method. Also it might be best to do the validation before as the very first step, before calling cls._to_tensor()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug, I included it within the function. However, it possible to pass list as input so it easier to run this test after the input has been converted to a tensor.

return cls._wrap(tensor, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)

@classmethod
Expand Down
Loading