Skip to content

Commit 01d3452

Browse files
committed
Add SetClampingMode transform
1 parent 51ca83e commit 01d3452

File tree

5 files changed

+72
-7
lines changed

5 files changed

+72
-7
lines changed

test/common_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def make_bounding_boxes(
410410
canvas_size=DEFAULT_SIZE,
411411
*,
412412
format=tv_tensors.BoundingBoxFormat.XYXY,
413-
clamping_mode="soft",
413+
clamping_mode="hard", # TODOBB
414414
num_boxes=1,
415415
dtype=None,
416416
device="cpu",
@@ -481,7 +481,7 @@ def sample_position(values, max_value):
481481
out_boxes[:, :2] += buffer // 2
482482
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
483483
out_boxes[:, :] += buffer // 2
484-
return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size)
484+
return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
485485

486486

487487
def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):

test/test_transforms_v2.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5589,6 +5589,54 @@ def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode,
55895589
else:
55905590
assert_equal(out, expected_clamped_output)
55915591

5592+
class TestSetClampingMode:
5593+
5594+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5595+
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none")) # TODOBB add soft
5596+
@pytest.mark.parametrize("desired_clamping_mode", ("hard", "none")) # TODOBB add soft
5597+
def test_setter(self, format, constructor_clamping_mode, desired_clamping_mode):
5598+
5599+
in_boxes = make_bounding_boxes(format=format, clamping_mode=constructor_clamping_mode)
5600+
out_boxes = transforms.SetClampingMode(clamping_mode=desired_clamping_mode)(in_boxes)
5601+
5602+
assert in_boxes.clamping_mode == constructor_clamping_mode # input is unchanged: no leak
5603+
assert out_boxes.clamping_mode == desired_clamping_mode
5604+
5605+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5606+
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none")) # TODOBB add soft
5607+
def test_pipeline_no_leak(self, format, constructor_clamping_mode):
5608+
5609+
class AssertClampingMode(transforms.Transform):
5610+
def __init__(self, expected_clamping_mode):
5611+
super().__init__()
5612+
self.expected_clamping_mode = expected_clamping_mode
5613+
5614+
_transformed_types = (tv_tensors.BoundingBoxes,)
5615+
5616+
def transform(self, inpt, _):
5617+
assert inpt.clamping_mode == self.expected_clamping_mode
5618+
return inpt
5619+
5620+
t = transforms.Compose(
5621+
[
5622+
transforms.SetClampingMode("none"),
5623+
AssertClampingMode("none"),
5624+
transforms.SetClampingMode("hard"),
5625+
AssertClampingMode("hard"),
5626+
transforms.SetClampingMode("none"),
5627+
AssertClampingMode("none"),
5628+
transforms.ClampBoundingBoxes("hard")
5629+
]
5630+
)
5631+
5632+
in_boxes = make_bounding_boxes(format=format, clamping_mode=constructor_clamping_mode)
5633+
out_boxes = t(in_boxes)
5634+
5635+
assert in_boxes.clamping_mode == constructor_clamping_mode # input is unchanged: no leak
5636+
5637+
# assert that the output boxes clamping_mode is the one set by the last SetClampingMode.
5638+
# ClampBoundingBoxes doesn't set clamping_mode.
5639+
assert out_boxes.clamping_mode == "none"
55925640

55935641

55945642
class TestClampKeyPoints:
@@ -7376,3 +7424,4 @@ def test_different_sizes(self, make_input1, make_input2, query):
73767424
def test_no_valid_input(self, query):
73777425
with pytest.raises(TypeError, match="No image"):
73787426
query(["blah"])
7427+

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
ScaleJitter,
4242
TenCrop,
4343
)
44-
from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat
44+
from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat, SetClampingMode
4545
from ._misc import (
4646
ConvertImageDtype,
4747
GaussianBlur,

torchvision/transforms/v2/_meta.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Union
1+
from typing import Any, Union, Optional
22

33
from torchvision import tv_tensors
44
from torchvision.transforms.v2 import functional as F, Transform
@@ -30,10 +30,10 @@ class ClampBoundingBoxes(Transform):
3030
The clamping is done according to the bounding boxes' ``canvas_size`` meta-data.
3131
3232
Args:
33-
clamping_mode: TODOBB more docs. Default is None which relies on the input box' .clamping_mode attribute.
33+
clamping_mode: TODOBB more docs. Default is None, which relies on the input box's clamping_mode attribute.
3434
3535
"""
36-
def __init__(self, clamping_mode: CLAMPING_MODE_TYPE = None) -> None:
36+
def __init__(self, clamping_mode: Optional[CLAMPING_MODE_TYPE] = None) -> None:
3737
super().__init__()
3838
self.clamping_mode = clamping_mode
3939

@@ -53,3 +53,18 @@ class ClampKeyPoints(Transform):
5353

5454
def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints:
5555
return F.clamp_keypoints(inpt) # type: ignore[return-value]
56+
57+
58+
class SetClampingMode(Transform):
59+
"""TODOBB"""
60+
def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None:
61+
super().__init__()
62+
# TODOBB validate mode
63+
self.clamping_mode = clamping_mode
64+
65+
_transformed_types = (tv_tensors.BoundingBoxes,)
66+
67+
def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
68+
out = inpt.clone()
69+
out.clamping_mode = self.clamping_mode
70+
return out

torchvision/tv_tensors/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@ def wrap(wrappee, *, like, **kwargs):
2323
wrappee (Tensor): The tensor to convert.
2424
like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
2525
``wrappee`` will be converted into the same subclass as ``like``.
26-
kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.tv_tensor.BoundingBoxes`.
26+
kwargs: Can contain "format", "canvas_size" and "clamping_mode" if ``like`` is a :class:`~torchvision.tv_tensor.BoundingBoxes`.
2727
Ignored otherwise.
2828
"""
2929
if isinstance(like, BoundingBoxes):
3030
return BoundingBoxes._wrap(
3131
wrappee,
3232
format=kwargs.get("format", like.format),
3333
canvas_size=kwargs.get("canvas_size", like.canvas_size),
34+
clamping_mode=kwargs.get("clamping_mode", like.clamping_mode),
3435
)
3536
elif isinstance(like, KeyPoints):
3637
return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))

0 commit comments

Comments
 (0)