Commit e417da1

disable int rotated boxes
1 parent 80cb38e commit e417da1

File tree (5 files changed: +49 -55 lines)

- test/common_utils.py
- test/test_transforms_v2.py
- torchvision/transforms/v2/functional/_geometry.py
- torchvision/transforms/v2/functional/_meta.py
- torchvision/tv_tensors/_bounding_boxes.py

test/common_utils.py

Lines changed: 8 additions & 15 deletions
@@ -424,13 +424,6 @@ def sample_position(values, max_value):
         format = tv_tensors.BoundingBoxFormat[format]
 
     dtype = dtype or torch.float32
-    int_dtype = dtype in (
-        torch.uint8,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-    )
 
     h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
     y = sample_position(h, canvas_size[0])
@@ -457,14 +450,14 @@ def sample_position(values, max_value):
     elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
         r_rad = r * torch.pi / 180.0
         cos, sin = torch.cos(r_rad), torch.sin(r_rad)
-        x1 = torch.round(x) if int_dtype else x
-        y1 = torch.round(y) if int_dtype else y
-        x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
-        y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
-        x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
-        y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
-        x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
-        y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
+        x1 = x
+        y1 = y
+        x2 = x1 + w * cos
+        y2 = y1 - w * sin
+        x3 = x2 + h * sin
+        y3 = y2 + h * cos
+        x4 = x1 + h * sin
+        y4 = y1 + h * cos
         parts = (x1, y1, x2, y2, x3, y3, x4, y4)
     else:
         raise ValueError(f"Format {format} is not supported")
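For orientation, the XYXYXYXY corner construction the helper performs (unchanged here apart from dropping the rounding branch) can be sketched standalone. The scalar values below are illustrative, not taken from the test suite:

import torch

# Illustrative values: one rotated box given by its top-left corner (x, y),
# width w, height h, and angle r in degrees.
x, y, w, h = torch.tensor([10.0, 20.0, 30.0, 15.0])
r_rad = torch.deg2rad(torch.tensor(25.0))
cos, sin = torch.cos(r_rad), torch.sin(r_rad)

# XYXYXYXY corners, kept in floating point; the removed rounding is no longer
# applied for integer dtypes.
x1, y1 = x, y
x2, y2 = x1 + w * cos, y1 - w * sin
x3, y3 = x2 + h * sin, y2 + h * cos
x4, y4 = x1 + h * sin, y1 + h * cos
corners = torch.stack([x1, y1, x2, y2, x3, y3, x4, y4])
print(corners)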

test/test_transforms_v2.py

Lines changed: 33 additions & 18 deletions
@@ -564,13 +564,6 @@ def reference_affine_rotated_bounding_boxes_helper(
 
     def affine_rotated_bounding_boxes(bounding_boxes):
         dtype = bounding_boxes.dtype
-        int_dtype = dtype in (
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int32,
-            torch.int64,
-        )
         device = bounding_boxes.device
 
         # Go to float before converting to prevent precision loss in case of CXCYWHR -> XYXYXYXY and W or H is 1
@@ -605,17 +598,12 @@ def affine_rotated_bounding_boxes(bounding_boxes):
         )
 
         output = output[[2, 3, 0, 1, 6, 7, 4, 5]] if flip else output
-        if not int_dtype:
-            output = _parallelogram_to_bounding_boxes(output)
+        output = _parallelogram_to_bounding_boxes(output)
 
         output = F.convert_bounding_box_format(
             output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
         )
 
-        if torch.is_floating_point(output) and int_dtype:
-            # It is important to round before cast.
-            output = torch.round(output)
-
         # For rotated boxes, it is important to cast before clamping.
         return (
             F.clamp_bounding_boxes(
@@ -760,6 +748,8 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype,
     def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device):
         if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
             return
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
 
         bounding_boxes = make_bounding_boxes(
             format=format,
@@ -1212,6 +1202,8 @@ def test_kernel_image(self, dtype, device):
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_boxes(self, format, dtype, device):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
         check_kernel(
             F.horizontal_flip_bounding_boxes,
@@ -1441,6 +1433,8 @@ def test_kernel_image(self, param, value, dtype, device):
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
         self._check_kernel(
             F.affine_bounding_boxes,
@@ -1823,6 +1817,8 @@ def test_kernel_image(self, dtype, device):
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_boxes(self, format, dtype, device):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
         check_kernel(
             F.vertical_flip_bounding_boxes,
@@ -2021,6 +2017,8 @@ def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
         kwargs = {param: value}
         if param != "angle":
             kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
 
         bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
 
@@ -3236,6 +3234,8 @@ def test_kernel_image(self, param, value, dtype, device):
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_boxes(self, format, dtype, device):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
 
         check_kernel(
@@ -3399,6 +3399,8 @@ def test_kernel_image(self, kwargs, dtype, device):
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_boxes(self, kwargs, format, dtype, device):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)
         check_kernel(F.crop_bounding_boxes, bounding_boxes, format=format, **kwargs)
 
@@ -3576,6 +3578,8 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)
 
         actual = F.crop(bounding_boxes, **kwargs)
@@ -3590,6 +3594,8 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("seed", list(range(5)))
     def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, device, seed):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         input_size = [s * 2 for s in output_size]
         bounding_boxes = make_bounding_boxes(input_size, format=format, dtype=dtype, device=device)
 
@@ -4267,6 +4273,10 @@ def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("fn_type", ["functional", "transform"])
     def test_correctness(self, old_format, new_format, dtype, device, fn_type):
+        if not dtype.is_floating_point and (
+            tv_tensors.is_rotated_bounding_format(old_format) or tv_tensors.is_rotated_bounding_format(new_format)
+        ):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device)
 
         if fn_type == "functional":
@@ -4706,6 +4716,8 @@ def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding):
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
     def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
 
         actual = fn(bounding_boxes, padding=padding)
@@ -4876,6 +4888,8 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
     def test_bounding_boxes_correctness(self, output_size, format, dtype, device, fn):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)
 
         actual = fn(bounding_boxes, output_size)
@@ -5242,6 +5256,8 @@ def perspective_bounding_boxes(bounding_boxes):
     @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
 
         actual = F.perspective(bounding_boxes, startpoints=startpoints, endpoints=endpoints)
@@ -5511,6 +5527,8 @@ class TestClampBoundingBoxes:
     @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel(self, format, clamping_mode, dtype, device):
+        if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
+            pytest.xfail("Rotated bounding boxes should be floating point tensors")
         bounding_boxes = make_bounding_boxes(format=format, clamping_mode=clamping_mode, dtype=dtype, device=device)
         check_kernel(
             F.clamp_bounding_boxes,
@@ -6938,14 +6956,11 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t
 
 
 @pytest.mark.parametrize("input_size", [(17, 11), (11, 17), (11, 11)])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
 @pytest.mark.parametrize("device", cpu_and_cuda())
-def test_parallelogram_to_bounding_boxes(input_size, dtype, device):
+def test_parallelogram_to_bounding_boxes(input_size, device):
     # Assert that applying `_parallelogram_to_bounding_boxes` to rotated boxes
     # does not modify the input.
-    bounding_boxes = make_bounding_boxes(
-        input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, dtype=dtype, device=device
-    )
+    bounding_boxes = make_bounding_boxes(input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, device=device)
     actual = _parallelogram_to_bounding_boxes(bounding_boxes)
     torch.testing.assert_close(actual, bounding_boxes, rtol=0, atol=1)
 
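The guard added across these tests is uniform: integer dtypes stay in the parametrization, but any pairing of an integer dtype with a rotated format is expected to fail. A minimal standalone sketch of that pattern (hypothetical test name; assumes the test module's imports and the make_bounding_boxes helper from test/common_utils.py):

import pytest
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

from common_utils import make_bounding_boxes


@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
def test_horizontal_flip_rejects_int_rotated(format, dtype):
    # Rotated formats combined with an integer dtype are expected to fail,
    # since BoundingBoxes construction now raises for that combination.
    if not dtype.is_floating_point and tv_tensors.is_rotated_bounding_format(format):
        pytest.xfail("Rotated bounding boxes should be floating point tensors")
    bounding_boxes = make_bounding_boxes(format=format, dtype=dtype)
    F.horizontal_flip(bounding_boxes)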

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 2 additions & 17 deletions
@@ -462,19 +462,6 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
         torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
         The output maintains the same dtype as the input.
     """
-    dtype = parallelogram.dtype
-    int_dtype = dtype in (
-        torch.uint8,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-    )
-    if int_dtype:
-        # Does not apply the transformation to `int` boxes as the rounding error
-        # will typically not ensure the resulting box has a rectangular shape.
-        return parallelogram.clone()
-
     out_boxes = parallelogram.clone()
 
     # Calculate parallelogram diagonal vectors
@@ -499,8 +486,8 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
         diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad)),
     )
 
-    delta_x = torch.round(w * cos).to(dtype) if int_dtype else w * cos
-    delta_y = torch.round(w * sin).to(dtype) if int_dtype else w * sin
+    delta_x = w * cos
+    delta_y = w * sin
     # Update coordinates to form a rectangle
     # Keeping the points (x1, y1) and (x3, y3) unchanged.
     out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2])
@@ -1196,8 +1183,6 @@ def _affine_bounding_boxes_with_expand(
     ).reshape(original_shape)
 
     if need_cast:
-        if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
-            out_bboxes.round_()
         out_bboxes = out_bboxes.to(dtype)
     return out_bboxes, canvas_size
 
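The comment removed from `_parallelogram_to_bounding_boxes` explained why the integer path was skipped in the first place: rounding corner coordinates generally destroys the rectangular shape. A standalone illustration of that point (illustrative numbers and a hypothetical helper, not torchvision code):

import torch

# Build an exact rotated rectangle in XYXYXYXY form, then round its corners
# the way the removed integer path did.
x1, y1, w, h = torch.tensor([10.0, 10.0, 20.0, 10.0])
r = torch.deg2rad(torch.tensor(30.0))
cos, sin = torch.cos(r), torch.sin(r)
exact = torch.stack(
    [x1, y1, x1 + w * cos, y1 - w * sin, x1 + w * cos + h * sin, y1 - w * sin + h * cos, x1 + h * sin, y1 + h * cos]
)
rounded = exact.round()


def corner_cosine(pts):
    # Cosine of the angle between the two edges meeting at the second corner;
    # it is zero for a true rectangle because adjacent edges are perpendicular.
    e12 = pts[2:4] - pts[0:2]
    e23 = pts[4:6] - pts[2:4]
    return torch.dot(e12, e23) / (e12.norm() * e23.norm())


print(corner_cosine(exact))    # ~0: still a rectangle
print(corner_cosine(rounded))  # non-zero after rounding: no longer a rectangle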

torchvision/transforms/v2/functional/_meta.py

Lines changed: 0 additions & 5 deletions
@@ -578,8 +578,6 @@ def _clamp_along_y_axis(
     bounding_boxes[..., 0].clamp_(0)  # Clamp x1 to 0
 
     if need_cast:
-        if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
-            bounding_boxes.round_()
         bounding_boxes = bounding_boxes.to(dtype)
     return bounding_boxes.reshape(original_shape)
 
@@ -646,9 +644,6 @@ def _clamp_rotated_bounding_boxes(
     ).reshape(original_shape)
 
     if need_cast:
-        if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
-            # Adding epsilon to ensure consistency between CPU and GPU rounding.
-            out_boxes.add_(1e-7).round_()
         out_boxes = out_boxes.to(dtype)
     return out_boxes
 
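With the rounding-and-epsilon branch gone, clamping a rotated box stays a purely floating point operation. A minimal sketch (box values illustrative; assumes a torchvision build that includes this commit):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# A single rotated box in CXCYWHR format: cx, cy, w, h, angle in degrees.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[50.0, 50.0, 120.0, 40.0, 30.0]]),
    format=tv_tensors.BoundingBoxFormat.CXCYWHR,
    canvas_size=(100, 100),
)
clamped = F.clamp_bounding_boxes(boxes)
print(clamped.dtype)  # torch.float32, no integer rounding involved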

torchvision/tv_tensors/_bounding_boxes.py

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,11 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_
         bounding_boxes.clamping_mode = clamping_mode
         return bounding_boxes
 
+    @staticmethod
+    def _check_format(tensor: torch.Tensor, format: BoundingBoxFormat) -> None:
+        if not torch.is_floating_point(tensor) and is_rotated_bounding_format(format):
+            raise ValueError("Rotated bounding boxes should be floating point tensors")
+
     def __new__(
         cls,
         data: Any,
@@ -111,6 +116,7 @@ def __new__(
         requires_grad: bool | None = None,
     ) -> BoundingBoxes:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        cls._check_format(tensor, format=format)
         return cls._wrap(tensor, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
 
     @classmethod
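Taken together, the user-visible change is enforced at construction time: `BoundingBoxes.__new__` now rejects integer tensors for rotated formats, while axis-aligned formats are unaffected. A minimal sketch (values illustrative; assumes a build that includes this commit):

import torch
from torchvision import tv_tensors

# Floating point rotated boxes are still accepted.
tv_tensors.BoundingBoxes(
    torch.tensor([[50.0, 50.0, 20.0, 10.0, 30.0]]),  # cx, cy, w, h, angle
    format=tv_tensors.BoundingBoxFormat.CXCYWHR,
    canvas_size=(100, 100),
)

# Integer rotated boxes now raise at construction.
try:
    tv_tensors.BoundingBoxes(
        torch.tensor([[50, 50, 20, 10, 30]], dtype=torch.int64),
        format=tv_tensors.BoundingBoxFormat.CXCYWHR,
        canvas_size=(100, 100),
    )
except ValueError as e:
    print(e)  # Rotated bounding boxes should be floating point tensors

# Integer axis-aligned boxes are unaffected.
tv_tensors.BoundingBoxes(
    torch.tensor([[10, 10, 40, 40]], dtype=torch.int64),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(100, 100),
)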
