Skip to content

Commit da7f360

Browse files
committed
Add clamping_mode parameter to clamp_bounding_boxes functional and class
1 parent 84379b5 commit da7f360

File tree

5 files changed

+88
-18
lines changed

5 files changed

+88
-18
lines changed

test/common_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ def make_bounding_boxes(
410410
canvas_size=DEFAULT_SIZE,
411411
*,
412412
format=tv_tensors.BoundingBoxFormat.XYXY,
413+
clamping_mode="soft",
413414
num_boxes=1,
414415
dtype=None,
415416
device="cpu",
@@ -474,7 +475,7 @@ def sample_position(values, max_value):
474475
# numerical issues during the testing
475476
buffer = 4
476477
out_boxes = clamp_bounding_boxes(
477-
out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)
478+
out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer), clamping_mode=clamping_mode
478479
)
479480
if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
480481
out_boxes[:, :2] += buffer // 2

test/test_transforms_v2.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5506,20 +5506,23 @@ def test_correctness_image(self, mean, std, dtype, fn):
55065506

55075507
class TestClampBoundingBoxes:
55085508
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5509+
@pytest.mark.parametrize("clamping_mode", ("hard", "none")) # TODOBB add soft
55095510
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
55105511
@pytest.mark.parametrize("device", cpu_and_cuda())
5511-
def test_kernel(self, format, dtype, device):
5512-
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
5512+
def test_kernel(self, format, clamping_mode, dtype, device):
5513+
bounding_boxes = make_bounding_boxes(format=format, clamping_mode=clamping_mode, dtype=dtype, device=device)
55135514
check_kernel(
55145515
F.clamp_bounding_boxes,
55155516
bounding_boxes,
55165517
format=bounding_boxes.format,
55175518
canvas_size=bounding_boxes.canvas_size,
5519+
clamping_mode=clamping_mode,
55185520
)
55195521

55205522
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5521-
def test_functional(self, format):
5522-
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format))
5523+
@pytest.mark.parametrize("clamping_mode", ("hard", "none")) # TODOBB add soft
5524+
def test_functional(self, format, clamping_mode):
5525+
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format, clamping_mode=clamping_mode))
55235526

55245527
def test_errors(self):
55255528
input_tv_tensor = make_bounding_boxes()
@@ -5540,6 +5543,47 @@ def test_errors(self):
55405543

55415544
def test_transform(self):
55425545
check_transform(transforms.ClampBoundingBoxes(), make_bounding_boxes())
5546+
5547+
@pytest.mark.parametrize("rotated", (True, False))
5548+
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none"))
5549+
@pytest.mark.parametrize("clamping_mode", ("hard", "none", None)) # TODOBB add soft here.
5550+
@pytest.mark.parametrize("pass_pure_tensor", (True, False))
5551+
@pytest.mark.parametrize("fn", [F.clamp_bounding_boxes, transform_cls_to_functional(transforms.ClampBoundingBoxes)])
5552+
def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn):
5553+
# This test checks 2 things:
5554+
# - That passing clamping_mode=None to the clamp_bounding_boxes
5555+
# functional (or to the class) relies on the box's `.clamping_mode`
5556+
# attribute
5557+
# - That clamping happens when it should, and only when it should, i.e.
5558+
# when the clamping mode is not "none". It doesn't validate the
5559+
# nunmerical results, only that clamping happened. For that, we create
5560+
# a large 100x100 box inside of a small 10x10 image.
5561+
5562+
if pass_pure_tensor and fn is not F.clamp_bounding_boxes:
5563+
# Only the functional supports pure tensors, not the class
5564+
return
5565+
if pass_pure_tensor and clamping_mode is None:
5566+
# cannot leave clamping_mode=None when passing pure tensor
5567+
return
5568+
5569+
if rotated:
5570+
boxes = tv_tensors.BoundingBoxes([0, 0, 100, 100, 0], format="XYWHR", canvas_size=(10, 10), clamping_mode=constructor_clamping_mode)
5571+
expected_clamped_output = torch.tensor([[0, 0, 10, 10, 0]])
5572+
else:
5573+
boxes = tv_tensors.BoundingBoxes([0, 100, 0, 100], format="XYXY", canvas_size=(10, 10), clamping_mode=constructor_clamping_mode)
5574+
expected_clamped_output = torch.tensor([[0, 10, 0, 10]])
5575+
5576+
if pass_pure_tensor:
5577+
out = fn(boxes.as_subclass(torch.Tensor), format=boxes.format, canvas_size=boxes.canvas_size, clamping_mode=clamping_mode)
5578+
else:
5579+
out = fn(boxes, clamping_mode=clamping_mode)
5580+
5581+
clamping_mode_prevailing = constructor_clamping_mode if clamping_mode is None else clamping_mode
5582+
if clamping_mode_prevailing == "none":
5583+
assert_equal(boxes, out) # should be a pass-through
5584+
else:
5585+
assert_equal(out, expected_clamped_output)
5586+
55435587

55445588

55455589
class TestClampKeyPoints:

torchvision/transforms/v2/_meta.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from torchvision import tv_tensors
44
from torchvision.transforms.v2 import functional as F, Transform
5+
from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
56

67

78
class ConvertBoundingBoxFormat(Transform):
@@ -28,12 +29,18 @@ class ClampBoundingBoxes(Transform):
2829
2930
The clamping is done according to the bounding boxes' ``canvas_size`` meta-data.
3031
32+
Args:
33+
clamping_mode: TODOBB more docs. Default is None which relies on the input box' .clamping_mode attribute.
34+
3135
"""
36+
def __init__(self, clamping_mode: CLAMPING_MODE_TYPE = None) -> None:
37+
super().__init__()
38+
self.clamping_mode = clamping_mode
3239

3340
_transformed_types = (tv_tensors.BoundingBoxes,)
3441

3542
def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
36-
return F.clamp_bounding_boxes(inpt) # type: ignore[return-value]
43+
return F.clamp_bounding_boxes(inpt, clamping_mode=self.clamping_mode) # type: ignore[return-value]
3744

3845

3946
class ClampKeyPoints(Transform):

torchvision/transforms/v2/functional/_meta.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torchvision import tv_tensors
66
from torchvision.transforms import _functional_pil as _FP
77
from torchvision.tv_tensors import BoundingBoxFormat
8+
from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
89

910
from torchvision.utils import _log_api_usage_once
1011

@@ -370,8 +371,11 @@ def convert_bounding_box_format(
370371

371372

372373
def _clamp_bounding_boxes(
373-
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int]
374+
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int],
375+
clamping_mode: Optional[CLAMPING_MODE_TYPE], # TODOBB shouldn't be Optional
374376
) -> torch.Tensor:
377+
if clamping_mode is not None and clamping_mode == "none":
378+
return bounding_boxes.clone()
375379
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
376380
# BoundingBoxFormat instead of converting back and forth
377381
in_dtype = bounding_boxes.dtype
@@ -477,7 +481,8 @@ def _clamp_along_y_axis(
477481

478482

479483
def _clamp_rotated_bounding_boxes(
480-
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int]
484+
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int],
485+
clamping_mode: Optional[CLAMPING_MODE_TYPE], # TODOBB shouldn't be Optional
481486
) -> torch.Tensor:
482487
"""
483488
Clamp rotated bounding boxes to ensure they stay within the canvas boundaries.
@@ -499,6 +504,8 @@ def _clamp_rotated_bounding_boxes(
499504
Returns:
500505
torch.Tensor: Clamped bounding boxes in the original format and shape
501506
"""
507+
if clamping_mode is not None and clamping_mode == "none":
508+
return bounding_boxes.clone()
502509
original_shape = bounding_boxes.shape
503510
dtype = bounding_boxes.dtype
504511
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
@@ -536,29 +543,33 @@ def clamp_bounding_boxes(
536543
inpt: torch.Tensor,
537544
format: Optional[BoundingBoxFormat] = None,
538545
canvas_size: Optional[tuple[int, int]] = None,
546+
clamping_mode: Optional[CLAMPING_MODE_TYPE] = None,
539547
) -> torch.Tensor:
540548
"""See :func:`~torchvision.transforms.v2.ClampBoundingBoxes` for details."""
541549
if not torch.jit.is_scripting():
542550
_log_api_usage_once(clamp_bounding_boxes)
543551

544552
if torch.jit.is_scripting() or is_pure_tensor(inpt):
545553

546-
if format is None or canvas_size is None:
547-
raise ValueError("For pure tensor inputs, `format` and `canvas_size` have to be passed.")
554+
# TODOBB
555+
if format is None or canvas_size is None or clamping_mode is None:
556+
raise ValueError("For pure tensor inputs, `format`, `canvas_size` and `clamping_mode` have to be passed.")
548557
if tv_tensors.is_rotated_bounding_format(format):
549-
return _clamp_rotated_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
558+
return _clamp_rotated_bounding_boxes(inpt, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
550559
else:
551-
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
560+
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
552561
elif isinstance(inpt, tv_tensors.BoundingBoxes):
553562
if format is not None or canvas_size is not None:
554563
raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
564+
if clamping_mode is None:
565+
clamping_mode = inpt.clamping_mode
555566
if tv_tensors.is_rotated_bounding_format(inpt.format):
556567
output = _clamp_rotated_bounding_boxes(
557-
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
568+
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, clamping_mode=clamping_mode
558569
)
559570
else:
560571
output = _clamp_bounding_boxes(
561-
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
572+
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, clamping_mode=clamping_mode
562573
)
563574
return tv_tensors.wrap(output, like=inpt)
564575
else:

torchvision/tv_tensors/_bounding_boxes.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Mapping, Sequence
44

55
from enum import Enum
6-
from typing import Any, Literal
6+
from typing import Any
77

88
import torch
99
from torch.utils._pytree import tree_flatten
@@ -46,7 +46,12 @@ def is_rotated_bounding_format(format: BoundingBoxFormat) -> bool:
4646
)
4747

4848

49-
CLAMPING_MODE_TYPE = Literal["hard", "soft", "none"]
49+
# TODOBB consider making this a Literal instead. Tried briefly and got
50+
# torchscript errors, leaving to str for now.
51+
# CLAMPING_MODE_TYPE = Literal["hard", "soft", "none"]
52+
CLAMPING_MODE_TYPE = str
53+
54+
# TODOBB All docs. Add any new API to rst files, add tutorial[s].
5055

5156

5257
class BoundingBoxes(TVTensor):
@@ -65,6 +70,7 @@ class BoundingBoxes(TVTensor):
6570
data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
6671
format (BoundingBoxFormat, str): Format of the bounding box.
6772
canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
73+
clamping_mode: TODOBB
6874
dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
6975
``data``.
7076
device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
@@ -89,6 +95,7 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_
8995
bounding_boxes = tensor.as_subclass(cls)
9096
bounding_boxes.format = format
9197
bounding_boxes.canvas_size = canvas_size
98+
# TODOBB validate values
9299
bounding_boxes.clamping_mode = clamping_mode
93100
return bounding_boxes
94101

@@ -98,13 +105,13 @@ def __new__(
98105
*,
99106
format: BoundingBoxFormat | str,
100107
canvas_size: tuple[int, int],
101-
clamping_mode: CLAMPING_MODE_TYPE = "soft",
108+
clamping_mode: CLAMPING_MODE_TYPE = "hard", # TODOBB change default to soft!
102109
dtype: torch.dtype | None = None,
103110
device: torch.device | str | int | None = None,
104111
requires_grad: bool | None = None,
105112
) -> BoundingBoxes:
106113
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
107-
return cls._wrap(tensor, format=format, canvas_size=canvas_size)
114+
return cls._wrap(tensor, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
108115

109116
@classmethod
110117
def _wrap_output(

0 commit comments

Comments
 (0)