Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,10 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
def make_keypoints(canvas_size=DEFAULT_SIZE, *, clamping_mode="soft", num_points=4, dtype=None, device="cpu"):
y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device)
x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device)
return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size)
return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size, clamping_mode=clamping_mode)


def make_bounding_boxes(
Expand Down
79 changes: 66 additions & 13 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):

def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, clamp=True):
canvas_size = new_canvas_size or keypoints.canvas_size
clamping_mode = keypoints.clamping_mode

def affine_keypoints(keypoints):
dtype = keypoints.dtype
Expand All @@ -652,7 +653,7 @@ def affine_keypoints(keypoints):
)

if clamp:
output = F.clamp_keypoints(output, canvas_size=canvas_size)
output = F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode)
else:
dtype = output.dtype

Expand All @@ -661,6 +662,7 @@ def affine_keypoints(keypoints):
return tv_tensors.KeyPoints(
torch.cat([affine_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(keypoints.shape),
canvas_size=canvas_size,
clamping_mode=clamping_mode,
)


Expand Down Expand Up @@ -2084,7 +2086,6 @@ def test_functional(self, make_input):
(F.rotate_image, tv_tensors.Image),
(F.rotate_mask, tv_tensors.Mask),
(F.rotate_video, tv_tensors.Video),
(F.rotate_keypoints, tv_tensors.KeyPoints),
],
)
def test_functional_signature(self, kernel, input_type):
Expand Down Expand Up @@ -3309,7 +3310,6 @@ def test_functional(self, make_input):
(F.elastic_image, tv_tensors.Image),
(F.elastic_mask, tv_tensors.Mask),
(F.elastic_video, tv_tensors.Video),
(F.elastic_keypoints, tv_tensors.KeyPoints),
],
)
def test_functional_signature(self, kernel, input_type):
Expand Down Expand Up @@ -4414,7 +4414,6 @@ def test_functional(self, make_input):
(F.resized_crop_image, tv_tensors.Image),
(F.resized_crop_mask, tv_tensors.Mask),
(F.resized_crop_video, tv_tensors.Video),
(F.resized_crop_keypoints, tv_tensors.KeyPoints),
],
)
def test_functional_signature(self, kernel, input_type):
Expand Down Expand Up @@ -5325,6 +5324,7 @@ def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, fo

def _reference_perspective_keypoints(self, keypoints, *, startpoints, endpoints):
canvas_size = keypoints.canvas_size
clamping_mode = keypoints.clamping_mode
dtype = keypoints.dtype
device = keypoints.device

Expand Down Expand Up @@ -5361,16 +5361,16 @@ def perspective_keypoints(keypoints):
)

# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
return F.clamp_keypoints(
output,
canvas_size=canvas_size,
).to(dtype=dtype, device=device)
return F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode).to(
dtype=dtype, device=device
)

return tv_tensors.KeyPoints(
torch.cat([perspective_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(
keypoints.shape
),
canvas_size=canvas_size,
clamping_mode=clamping_mode,
)

@pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS)
Expand Down Expand Up @@ -5733,32 +5733,85 @@ def test_error(self):


class TestClampKeyPoints:
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None))
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel(self, dtype, device):
keypoints = make_keypoints(dtype=dtype, device=device)
def test_kernel(self, clamping_mode, dtype, device):
keypoints = make_keypoints(dtype=dtype, device=device, clamping_mode=clamping_mode)
check_kernel(
F.clamp_keypoints,
keypoints,
canvas_size=keypoints.canvas_size,
clamping_mode=clamping_mode,
)

def test_functional(self):
check_functional(F.clamp_keypoints, make_keypoints())
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None))
def test_functional(self, clamping_mode):
check_functional(F.clamp_keypoints, make_keypoints(clamping_mode=clamping_mode))

def test_errors(self):
input_tv_tensor = make_keypoints()
input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)

with pytest.raises(ValueError, match="`canvas_size` has to be passed"):
with pytest.raises(ValueError, match="`canvas_size` and `clamping_mode` have to be passed."):
F.clamp_keypoints(input_pure_tensor, canvas_size=None)

with pytest.raises(ValueError, match="`canvas_size` must not be passed"):
F.clamp_keypoints(input_tv_tensor, canvas_size=input_tv_tensor.canvas_size)
with pytest.raises(ValueError, match="clamping_mode must be soft,"):
F.clamp_keypoints(input_tv_tensor, clamping_mode="bad")
with pytest.raises(ValueError, match="clamping_mode must be soft,"):
transforms.ClampKeyPoints(clamping_mode="bad")(input_tv_tensor)

def test_transform(self):
check_transform(transforms.ClampKeyPoints(), make_keypoints())

@pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None))
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None, "auto"))
@pytest.mark.parametrize("pass_pure_tensor", (True, False))
@pytest.mark.parametrize("fn", [F.clamp_keypoints, transform_cls_to_functional(transforms.ClampKeyPoints)])
def test_clamping_mode(self, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn):
# This test checks 2 things:
# - That passing clamping_mode=None to the clamp_keypointss
# functional (or to the class) relies on the box's `.clamping_mode`
# attribute
# - That clamping happens when it should, and only when it should, i.e.
# when the clamping mode is not None. It doesn't validate the
# numerical results, only that clamping happened. For that, we create
# a keypoints with large coordinates (100) inside of a small 10x10 image.

if pass_pure_tensor and fn is not F.clamp_keypoints:
# Only the functional supports pure tensors, not the class
return
if pass_pure_tensor and clamping_mode == "auto":
# cannot leave clamping_mode="auto" when passing pure tensor
return

keypoints = tv_tensors.KeyPoints(
[[0, 100], [0, 100]], canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
)
expected_clamped_output = (

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this line is redundant and can be removed.

torch.tensor([[0, 10], [0, 10]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
)
expected_clamped_output = (
torch.tensor([[0, 9], [0, 9]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
)

if pass_pure_tensor:
out = fn(
keypoints.as_subclass(torch.Tensor),
canvas_size=keypoints.canvas_size,
clamping_mode=clamping_mode,
)
else:
out = fn(keypoints, clamping_mode=clamping_mode)

clamping_mode_prevailing = constructor_clamping_mode if clamping_mode == "auto" else clamping_mode
if clamping_mode_prevailing is None:
assert_equal(keypoints, out) # should be a pass-through
else:
assert_equal(out, expected_clamped_output)


class TestInvert:
@pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.float32])
Expand Down
21 changes: 16 additions & 5 deletions torchvision/transforms/v2/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
from torchvision.tv_tensors import CLAMPING_MODE_TYPE


class ConvertBoundingBoxFormat(Transform):
Expand Down Expand Up @@ -46,17 +46,27 @@ def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> t
class ClampKeyPoints(Transform):
"""Clamp keypoints to their corresponding image dimensions.

The clamping is done according to the keypoints' ``canvas_size`` meta-data.
Args:
clamping_mode: Default is "auto" which relies on the input keypoint'
``clamping_mode`` attribute.
The clamping is done according to the keypoints' ``canvas_size`` meta-data.
Read more in :ref:`clamping_mode_tuto`

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clamping_mode_tuto in the docs currently only covers bounding boxes and would need to be updated as well.

for more details on how to use this transform.

"""

def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None:
super().__init__()
self.clamping_mode = clamping_mode

_transformed_types = (tv_tensors.KeyPoints,)

def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints:
return F.clamp_keypoints(inpt) # type: ignore[return-value]
return F.clamp_keypoints(inpt, clamping_mode=self.clamping_mode) # type: ignore[return-value]


class SetClampingMode(Transform):
"""Sets the ``clamping_mode`` attribute of the bounding boxes for future transforms.
"""Sets the ``clamping_mode`` attribute of the bounding boxes and keypoints for future transforms.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be useful to allow setting the clamping modes of bounding boxes and keypoints to different values by passing a dictionary.

For example:

  • SetClampingMode("soft") sets the clamping mode of both bounding boxes and keypoints to "soft".
  • SetClampingMode({tv_tensors.BoundingBoxes: "hard", tv_tensors.KeyPoints: "soft"}) sets the clamping mode of bounding boxes to "hard" and that of keypoints to "soft".
  • SetClampingMode({tv_tensors.BoundingBoxes: "hard"}) sets the clamping mode of bounding boxes to "hard" and leaves that of keypoints unchanged.




Expand All @@ -73,9 +83,10 @@ def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None:
if self.clamping_mode not in (None, "soft", "hard"):
raise ValueError(f"clamping_mode must be soft, hard or None, got {clamping_mode}")

_transformed_types = (tv_tensors.BoundingBoxes,)
_transformed_types = (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints)

def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
# this method works for both `tv_tensors.BoundingBoxes`` and `tv_tensors.KeyPoints`.
out: tv_tensors.BoundingBoxes = inpt.clone() # type: ignore[assignment]
out.clamping_mode = self.clamping_mode
return out
Loading