diff --git a/test/common_utils.py b/test/common_utils.py index 74ad31fea72..6dbf34db9bb 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -400,10 +400,10 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): +def make_keypoints(canvas_size=DEFAULT_SIZE, *, clamping_mode="soft", num_points=4, dtype=None, device="cpu"): y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device) x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device) - return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size) + return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size, clamping_mode=clamping_mode) def make_bounding_boxes( diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f92f2a0bc67..b5a8e4788c3 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -633,6 +633,7 @@ def affine_rotated_bounding_boxes(bounding_boxes): def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, clamp=True): canvas_size = new_canvas_size or keypoints.canvas_size + clamping_mode = keypoints.clamping_mode def affine_keypoints(keypoints): dtype = keypoints.dtype @@ -652,7 +653,7 @@ def affine_keypoints(keypoints): ) if clamp: - output = F.clamp_keypoints(output, canvas_size=canvas_size) + output = F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode) else: dtype = output.dtype @@ -661,6 +662,7 @@ def affine_keypoints(keypoints): return tv_tensors.KeyPoints( torch.cat([affine_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(keypoints.shape), canvas_size=canvas_size, + clamping_mode=clamping_mode, ) @@ -2084,7 +2086,6 @@ def test_functional(self, make_input): (F.rotate_image, tv_tensors.Image), (F.rotate_mask, tv_tensors.Mask), (F.rotate_video, tv_tensors.Video), - (F.rotate_keypoints, tv_tensors.KeyPoints), ], ) def test_functional_signature(self, kernel, input_type): @@ -3309,7 +3310,6 @@ def test_functional(self, make_input): (F.elastic_image, tv_tensors.Image), (F.elastic_mask, tv_tensors.Mask), (F.elastic_video, tv_tensors.Video), - (F.elastic_keypoints, tv_tensors.KeyPoints), ], ) def test_functional_signature(self, kernel, input_type): @@ -4414,7 +4414,6 @@ def test_functional(self, make_input): (F.resized_crop_image, tv_tensors.Image), (F.resized_crop_mask, tv_tensors.Mask), (F.resized_crop_video, tv_tensors.Video), - (F.resized_crop_keypoints, tv_tensors.KeyPoints), ], ) def test_functional_signature(self, kernel, input_type): @@ -5325,6 +5324,7 @@ def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, fo def _reference_perspective_keypoints(self, keypoints, *, startpoints, endpoints): canvas_size = keypoints.canvas_size + clamping_mode = keypoints.clamping_mode dtype = keypoints.dtype device = keypoints.device @@ -5361,16 +5361,16 @@ def perspective_keypoints(keypoints): ) # It is important to clamp before casting, especially for CXCYWH format, dtype=int64 - return F.clamp_keypoints( - output, - canvas_size=canvas_size, - ).to(dtype=dtype, device=device) + return F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode).to( + dtype=dtype, device=device + ) return tv_tensors.KeyPoints( torch.cat([perspective_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape( keypoints.shape ), 
canvas_size=canvas_size, + clamping_mode=clamping_mode, ) @pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS) @@ -5733,32 +5733,85 @@ def test_error(self): class TestClampKeyPoints: + @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel(self, dtype, device): - keypoints = make_keypoints(dtype=dtype, device=device) + def test_kernel(self, clamping_mode, dtype, device): + keypoints = make_keypoints(dtype=dtype, device=device, clamping_mode=clamping_mode) check_kernel( F.clamp_keypoints, keypoints, canvas_size=keypoints.canvas_size, + clamping_mode=clamping_mode, ) - def test_functional(self): - check_functional(F.clamp_keypoints, make_keypoints()) + @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None)) + def test_functional(self, clamping_mode): + check_functional(F.clamp_keypoints, make_keypoints(clamping_mode=clamping_mode)) def test_errors(self): input_tv_tensor = make_keypoints() input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor) - with pytest.raises(ValueError, match="`canvas_size` has to be passed"): + with pytest.raises(ValueError, match="`canvas_size` and `clamping_mode` have to be passed."): F.clamp_keypoints(input_pure_tensor, canvas_size=None) with pytest.raises(ValueError, match="`canvas_size` must not be passed"): F.clamp_keypoints(input_tv_tensor, canvas_size=input_tv_tensor.canvas_size) + with pytest.raises(ValueError, match="clamping_mode must be soft,"): + F.clamp_keypoints(input_tv_tensor, clamping_mode="bad") + with pytest.raises(ValueError, match="clamping_mode must be soft,"): + transforms.ClampKeyPoints(clamping_mode="bad")(input_tv_tensor) def test_transform(self): check_transform(transforms.ClampKeyPoints(), make_keypoints()) + @pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None)) + @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None, "auto")) + @pytest.mark.parametrize("pass_pure_tensor", (True, False)) + @pytest.mark.parametrize("fn", [F.clamp_keypoints, transform_cls_to_functional(transforms.ClampKeyPoints)]) + def test_clamping_mode(self, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn): + # This test checks 2 things: + # - That passing clamping_mode="auto" to the clamp_keypoints + # functional (or to the class) relies on the keypoints' `.clamping_mode` + # attribute + # - That clamping happens when it should, and only when it should, i.e. + # when the clamping mode is not None. It doesn't validate the + # numerical results, only that clamping happened. For that, we create + # keypoints with large coordinates (100) inside of a small 10x10 image.
+ + if pass_pure_tensor and fn is not F.clamp_keypoints: + # Only the functional supports pure tensors, not the class + return + if pass_pure_tensor and clamping_mode == "auto": + # cannot leave clamping_mode="auto" when passing pure tensor + return + + keypoints = tv_tensors.KeyPoints( + [[0, 100], [0, 100]], canvas_size=(10, 10), clamping_mode=constructor_clamping_mode + ) + clamping_mode_prevailing = constructor_clamping_mode if clamping_mode == "auto" else clamping_mode + expected_clamped_output = ( + torch.tensor([[0, 9], [0, 9]]) if clamping_mode_prevailing == "hard" else torch.tensor([[0, 100], [0, 100]]) + ) + + if pass_pure_tensor: + out = fn( + keypoints.as_subclass(torch.Tensor), + canvas_size=keypoints.canvas_size, + clamping_mode=clamping_mode, + ) + else: + out = fn(keypoints, clamping_mode=clamping_mode) + + if clamping_mode_prevailing is None: + assert_equal(keypoints, out) # should be a pass-through + else: + assert_equal(out, expected_clamped_output) + class TestInvert: @pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.float32]) diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index 39f223f0398..c23da1a36bc 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -2,7 +2,7 @@ from torchvision import tv_tensors from torchvision.transforms.v2 import functional as F, Transform -from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE +from torchvision.tv_tensors import CLAMPING_MODE_TYPE class ConvertBoundingBoxFormat(Transform): @@ -46,17 +46,27 @@ def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> t class ClampKeyPoints(Transform): """Clamp keypoints to their corresponding image dimensions. - The clamping is done according to the keypoints' ``canvas_size`` meta-data. + Args: + clamping_mode: Default is "auto", which relies on the input keypoints' + ``clamping_mode`` attribute. + The clamping is done according to the keypoints' ``canvas_size`` meta-data. + See :ref:`clamping_mode_tuto` + for more details on how to use this transform. + """ + def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None: + super().__init__() + self.clamping_mode = clamping_mode + _transformed_types = (tv_tensors.KeyPoints,) def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints: - return F.clamp_keypoints(inpt) # type: ignore[return-value] + return F.clamp_keypoints(inpt, clamping_mode=self.clamping_mode) # type: ignore[return-value] class SetClampingMode(Transform): - """Sets the ``clamping_mode`` attribute of the bounding boxes for future transforms. + """Sets the ``clamping_mode`` attribute of the bounding boxes and keypoints for future transforms. @@ -73,9 +83,10 @@ def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None: if self.clamping_mode not in (None, "soft", "hard"): raise ValueError(f"clamping_mode must be soft, hard or None, got {clamping_mode}") - _transformed_types = (tv_tensors.BoundingBoxes,) + _transformed_types = (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints) def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes: + # This method works for both ``tv_tensors.BoundingBoxes`` and ``tv_tensors.KeyPoints``.
out: tv_tensors.BoundingBoxes = inpt.clone() # type: ignore[assignment] out.clamping_mode = self.clamping_mode return out diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 0c7eab0c04e..0f1acdd887b 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -20,7 +20,7 @@ pil_to_tensor, to_pil_image, ) -from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE +from torchvision.tv_tensors import CLAMPING_MODE_TYPE from torchvision.utils import _log_api_usage_once @@ -67,16 +67,20 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: return horizontal_flip_image(mask) -def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]): +def horizontal_flip_keypoints( + keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft" +): shape = keypoints.shape keypoints = keypoints.clone().reshape(-1, 2) keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_() - return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size) + return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size, clamping_mode=clamping_mode) @_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _horizontal_flip_keypoints_dispatch(keypoints: tv_tensors.KeyPoints): - out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size) + out = horizontal_flip_keypoints( + keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size, clamping_mode=keypoints.clamping_mode + ) return tv_tensors.wrap(out, like=keypoints) @@ -155,11 +159,15 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image(mask) -def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor: +def vertical_flip_keypoints( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", +) -> torch.Tensor: shape = keypoints.shape keypoints = keypoints.clone().reshape(-1, 2) keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_() - return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size) + return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size, clamping_mode=clamping_mode) def vertical_flip_bounding_boxes( @@ -199,7 +207,9 @@ def vertical_flip_bounding_boxes( @_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _vertical_flip_keypoints_dispatch(inpt: tv_tensors.KeyPoints) -> tv_tensors.KeyPoints: - output = vertical_flip_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size) + output = vertical_flip_keypoints( + inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=inpt.clamping_mode + ) return tv_tensors.wrap(output, like=inpt) @@ -968,6 +978,7 @@ def _affine_keypoints_with_expand( shear: list[float], center: Optional[list[float]] = None, expand: bool = False, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: if keypoints.numel() == 0: return keypoints, canvas_size @@ -1026,7 +1037,9 @@ def _affine_keypoints_with_expand( new_width, new_height = _compute_affine_output_size(affine_vector, width, height) canvas_size = (new_height, new_width) - out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size).reshape(original_shape) + out_keypoints = 
clamp_keypoints(transformed_points, canvas_size=canvas_size, clamping_mode=clamping_mode).reshape( + original_shape + ) out_keypoints = out_keypoints.to(original_dtype) return out_keypoints, canvas_size @@ -1040,6 +1053,7 @@ def affine_keypoints( scale: float, shear: list[float], center: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ): return _affine_keypoints_with_expand( keypoints=keypoints, @@ -1050,6 +1064,7 @@ def affine_keypoints( shear=shear, center=center, expand=False, + clamping_mode=clamping_mode, ) @@ -1071,6 +1086,7 @@ def _affine_keypoints_dispatch( scale=scale, shear=shear, center=center, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -1393,6 +1409,7 @@ def rotate_keypoints( angle: float, expand: bool = False, center: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: return _affine_keypoints_with_expand( keypoints=keypoints, @@ -1403,6 +1420,7 @@ def rotate_keypoints( shear=[0.0, 0.0], center=center, expand=expand, + clamping_mode=clamping_mode, ) @@ -1411,7 +1429,12 @@ def _rotate_keypoints_dispatch( inpt: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[list[float]] = None, **kwargs ) -> tv_tensors.KeyPoints: output, canvas_size = rotate_keypoints( - inpt, canvas_size=inpt.canvas_size, angle=angle, center=center, expand=expand + inpt, + canvas_size=inpt.canvas_size, + angle=angle, + center=center, + expand=expand, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -1683,7 +1706,11 @@ def pad_mask( def pad_keypoints( - keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant" + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + padding: list[int], + padding_mode: str = "constant", + clamping_mode: CLAMPING_MODE_TYPE = "soft", ): SUPPORTED_MODES = ["constant"] if padding_mode not in SUPPORTED_MODES: @@ -1695,20 +1722,21 @@ def pad_keypoints( left, right, top, bottom = _parse_pad_padding(padding) pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device) canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right) - return clamp_keypoints(keypoints + pad, canvas_size), canvas_size + return clamp_keypoints(keypoints + pad, canvas_size=canvas_size, clamping_mode=clamping_mode), canvas_size @_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _pad_keypoints_dispatch( - keypoints: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs + inpt: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs ) -> tv_tensors.KeyPoints: output, canvas_size = pad_keypoints( - keypoints.as_subclass(torch.Tensor), - canvas_size=keypoints.canvas_size, + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, padding=padding, padding_mode=padding_mode, + clamping_mode=inpt.clamping_mode, ) - return tv_tensors.wrap(output, like=keypoints, canvas_size=canvas_size) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) def pad_bounding_boxes( @@ -1812,19 +1840,22 @@ def crop_keypoints( left: int, height: int, width: int, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: keypoints = keypoints - torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device) canvas_size = (height, width) - return 
clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size + return clamp_keypoints(keypoints, canvas_size=canvas_size, clamping_mode=clamping_mode), canvas_size @_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _crop_keypoints_dispatch( inpt: tv_tensors.KeyPoints, top: int, left: int, height: int, width: int ) -> tv_tensors.KeyPoints: - output, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width) + output, canvas_size = crop_keypoints( + inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, clamping_mode=inpt.clamping_mode + ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -2024,6 +2055,7 @@ def perspective_keypoints( startpoints: Optional[list[list[int]]], endpoints: Optional[list[list[int]]], coefficients: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ): if keypoints.numel() == 0: return keypoints @@ -2047,7 +2079,9 @@ def perspective_keypoints( numer_points = torch.matmul(points, theta1.T) denom_points = torch.matmul(points, theta2.T) transformed_points = numer_points.div_(denom_points) - return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size).reshape(original_shape) + return clamp_keypoints( + transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode + ).reshape(original_shape) @_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False) @@ -2064,6 +2098,7 @@ def _perspective_keypoints_dispatch( startpoints=startpoints, endpoints=endpoints, coefficients=coefficients, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt) @@ -2344,7 +2379,10 @@ def _create_identity_grid(size: tuple[int, int], device: torch.device, dtype: to def elastic_keypoints( - keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + displacement: torch.Tensor, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> torch.Tensor: expected_shape = (1, canvas_size[0], canvas_size[1], 2) if not isinstance(displacement, torch.Tensor): @@ -2376,12 +2414,19 @@ def elastic_keypoints( t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype) transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) - return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size).reshape(original_shape) + return clamp_keypoints( + transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode + ).reshape(original_shape) @_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _elastic_keypoints_dispatch(inpt: tv_tensors.KeyPoints, displacement: torch.Tensor, **kwargs): - output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement) + output = elastic_keypoints( + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, + displacement=displacement, + clamping_mode=inpt.clamping_mode, + ) return tv_tensors.wrap(output, like=inpt) @@ -2578,16 +2623,26 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: list[int]) -> PI return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) -def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int]): +def center_crop_keypoints( + inpt: torch.Tensor, + canvas_size: 
tuple[int, int], + output_size: list[int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", +): crop_height, crop_width = _center_crop_parse_output_size(output_size) crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size) - return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width) + return crop_keypoints( + inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width, clamping_mode=clamping_mode + ) @_register_kernel_internal(center_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: list[int]) -> tv_tensors.KeyPoints: output, canvas_size = center_crop_keypoints( - inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, + output_size=output_size, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -2745,8 +2800,9 @@ def resized_crop_keypoints( height: int, width: int, size: list[int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: - keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width) + keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width, clamping_mode=clamping_mode) return resize_keypoints(keypoints, size=size, canvas_size=canvas_size) @@ -2755,7 +2811,13 @@ def _resized_crop_keypoints_dispatch( inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs ): output, canvas_size = resized_crop_keypoints( - inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size + inpt.as_subclass(torch.Tensor), + top=top, + left=left, + height=height, + width=width, + size=size, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 4568b39ab59..a0be77181b6 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -4,8 +4,7 @@ import torch from torchvision import tv_tensors from torchvision.transforms import _functional_pil as _FP -from torchvision.tv_tensors import BoundingBoxFormat -from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE +from torchvision.tv_tensors import BoundingBoxFormat, CLAMPING_MODE_TYPE from torchvision.utils import _log_api_usage_once @@ -653,7 +652,11 @@ def clamp_bounding_boxes( ) -def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor: +def _clamp_keypoints( + keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE +) -> torch.Tensor: + if clamping_mode is None or clamping_mode != "hard": + return keypoints.clone() dtype = keypoints.dtype keypoints = keypoints.clone() if keypoints.is_floating_point() else keypoints.float() # Note that max is canvas_size[i] - 1 and not can canvas_size[i] like for @@ -666,20 +669,28 @@ def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> t def clamp_keypoints( inpt: torch.Tensor, canvas_size: Optional[tuple[int, int]] = None, + clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto", ) -> torch.Tensor: """See :func:`~torchvision.transforms.v2.ClampKeyPoints` for details.""" if not torch.jit.is_scripting(): _log_api_usage_once(clamp_keypoints) + if 
clamping_mode is not None and clamping_mode not in ("soft", "hard", "auto"): + raise ValueError(f"clamping_mode must be soft, hard, auto or None, got {clamping_mode}") + if torch.jit.is_scripting() or is_pure_tensor(inpt): - if canvas_size is None: - raise ValueError("For pure tensor inputs, `canvas_size` has to be passed.") - return _clamp_keypoints(inpt, canvas_size=canvas_size) + if canvas_size is None or (clamping_mode is not None and clamping_mode == "auto"): + raise ValueError("For pure tensor inputs, `canvas_size` and `clamping_mode` have to be passed.") + return _clamp_keypoints(inpt, canvas_size=canvas_size, clamping_mode=clamping_mode) elif isinstance(inpt, tv_tensors.KeyPoints): if canvas_size is not None: raise ValueError("For keypoints tv_tensor inputs, `canvas_size` must not be passed.") - output = _clamp_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size) + if clamping_mode is not None and clamping_mode == "auto": + clamping_mode = inpt.clamping_mode + output = _clamp_keypoints( + inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=clamping_mode + ) return tv_tensors.wrap(output, like=inpt) else: raise TypeError(f"Input can either be a plain tensor or a keypoints tv_tensor, but got {type(inpt)} instead.") diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index 744e5241135..1e6f12fb7f7 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -1,6 +1,6 @@ import torch -from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format +from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, CLAMPING_MODE_TYPE, is_rotated_bounding_format from ._image import Image from ._keypoints import KeyPoints from ._mask import Mask @@ -34,6 +34,10 @@ def wrap(wrappee, *, like, **kwargs): clamping_mode=kwargs.get("clamping_mode", like.clamping_mode), ) elif isinstance(like, KeyPoints): - return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size)) + return KeyPoints._wrap( + wrappee, + canvas_size=kwargs.get("canvas_size", like.canvas_size), + clamping_mode=kwargs.get("clamping_mode", like.clamping_mode), + ) else: return wrappee.as_subclass(type(like)) diff --git a/torchvision/tv_tensors/_keypoints.py b/torchvision/tv_tensors/_keypoints.py index aede31ad7db..51633031bf6 100644 --- a/torchvision/tv_tensors/_keypoints.py +++ b/torchvision/tv_tensors/_keypoints.py @@ -5,6 +5,8 @@ import torch from torch.utils._pytree import tree_flatten +from ._bounding_boxes import CLAMPING_MODE_TYPE + from ._tv_tensor import TVTensor @@ -43,6 +45,8 @@ class KeyPoints(TVTensor): :func:`torch.as_tensor`. canvas_size (two-tuple of ints): Height and width of the corresponding image or video. + clamping_mode: The clamping mode to use when applying transforms that may result in key points + outside of the image. Possible values are: "soft", "hard", or ``None``. Read more in :ref:`clamping_mode_tuto`. dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from ``data``. device (torch.device, optional): Desired device of the bounding box.
If @@ -55,16 +59,20 @@ class KeyPoints(TVTensor): """ canvas_size: tuple[int, int] + clamping_mode: CLAMPING_MODE_TYPE @classmethod - def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], check_dims: bool = True) -> KeyPoints: # type: ignore[override] + def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft", check_dims: bool = True) -> KeyPoints: # type: ignore[override] if check_dims: if tensor.ndim == 1: tensor = tensor.unsqueeze(0) elif tensor.shape[-1] != 2: raise ValueError(f"Expected a tensor of shape (..., 2), not {tensor.shape}") + if clamping_mode is not None and clamping_mode not in ("hard", "soft"): + raise ValueError(f"clamping_mode must be None, hard or soft, got {clamping_mode}.") points = tensor.as_subclass(cls) points.canvas_size = canvas_size + points.clamping_mode = clamping_mode return points def __new__( @@ -72,12 +80,13 @@ def __new__( data: Any, *, canvas_size: tuple[int, int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", dtype: torch.dtype | None = None, device: torch.device | str | int | None = None, requires_grad: bool | None = None, ) -> KeyPoints: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) - return cls._wrap(tensor, canvas_size=canvas_size) + return cls._wrap(tensor, canvas_size=canvas_size, clamping_mode=clamping_mode) @classmethod def _wrap_output( @@ -89,14 +98,17 @@ def _wrap_output( # Similar to BoundingBoxes._wrap_output(), see comment there. flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator] first_keypoints_from_args = next(x for x in flat_params if isinstance(x, KeyPoints)) - canvas_size = first_keypoints_from_args.canvas_size + canvas_size, clamping_mode = first_keypoints_from_args.canvas_size, first_keypoints_from_args.clamping_mode if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints): - output = KeyPoints._wrap(output, canvas_size=canvas_size, check_dims=False) + output = KeyPoints._wrap(output, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False) elif isinstance(output, (tuple, list)): # This branch exists for chunk() and unbind() - output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, check_dims=False) for part in output) + output = type(output)( + KeyPoints._wrap(part, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False) + for part in output + ) return output def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(canvas_size=self.canvas_size) + return self._make_repr(canvas_size=self.canvas_size, clamping_mode=self.clamping_mode)
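Usage note (not part of the patch): the sketch below illustrates how the keypoints clamping_mode plumbing added above is expected to be exercised once the diff is applied. It is a minimal example, not a definitive reference; it assumes that "auto" resolves to the input's clamping_mode attribute, that "hard" clamps coordinates to canvas_size - 1 as in _clamp_keypoints, and that "soft" and None currently leave keypoint values unchanged.

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F

# KeyPoints now carry a clamping_mode ("soft" by default), mirroring BoundingBoxes.
kp = tv_tensors.KeyPoints([[0, 100], [0, 100]], canvas_size=(10, 10), clamping_mode="hard")

# clamping_mode="auto" (the default) resolves to the input's attribute, so the
# out-of-canvas coordinate is hard-clamped to canvas_size - 1.
assert F.clamp_keypoints(kp).tolist() == [[0, 9], [0, 9]]

# Pure tensors need both an explicit canvas_size and a concrete clamping_mode.
out = F.clamp_keypoints(kp.as_subclass(torch.Tensor), canvas_size=(10, 10), clamping_mode="hard")

# "soft" and None are currently pass-throughs for keypoints; only "hard" changes values.
soft_kp = tv_tensors.KeyPoints([[0, 100]], canvas_size=(10, 10), clamping_mode="soft")
assert F.clamp_keypoints(soft_kp).tolist() == [[0, 100]]

# The transforms mirror the functional, and SetClampingMode now also updates KeyPoints.
pipeline = transforms.Compose(
    [
        transforms.SetClampingMode(clamping_mode="hard"),
        transforms.ClampKeyPoints(),  # "auto" -> uses the attribute set by SetClampingMode
    ]
)
assert pipeline(soft_kp).tolist() == [[0, 9]]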
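A second short sketch (also not part of the patch) of the kernel/dispatcher split used throughout _geometry.py: plain-tensor kernels take an explicit clamping_mode argument (defaulting to "soft"), while the registered KeyPoints dispatchers read it from the input and tv_tensors.wrap() propagates it to the output. The function names follow the kernels changed above; behaviour is assumed from this diff rather than from a released torchvision version.

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

kp = tv_tensors.KeyPoints([[2.0, 3.0]], canvas_size=(10, 10), clamping_mode=None)

# Dispatcher path: clamping_mode is taken from the input and survives wrapping.
flipped = F.horizontal_flip(kp)
assert isinstance(flipped, tv_tensors.KeyPoints) and flipped.clamping_mode is None

# Kernel path: plain tensors are clamped according to the explicit argument.
out = F.horizontal_flip_keypoints(
    kp.as_subclass(torch.Tensor), canvas_size=(10, 10), clamping_mode="hard"
)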