Skip to content

Commit c92bc32

Browse files
add clamping_mode parameter to KeyPoints constructor
1 parent 58eb039 commit c92bc32

File tree

2 files changed: +16 additions, -8 deletions

torchvision/tv_tensors/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

3-
from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format
3+
from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format, CLAMPING_MODE_TYPE
44
from ._image import Image
55
from ._keypoints import KeyPoints
66
from ._mask import Mask
@@ -34,6 +34,6 @@ def wrap(wrappee, *, like, **kwargs):
3434
clamping_mode=kwargs.get("clamping_mode", like.clamping_mode),
3535
)
3636
elif isinstance(like, KeyPoints):
37-
return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))
37+
return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size), clamping_mode=kwargs.get("clamping_mode", like.clamping_mode))
3838
else:
3939
return wrappee.as_subclass(type(like))

torchvision/tv_tensors/_keypoints.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch.utils._pytree import tree_flatten
77

88
from ._tv_tensor import TVTensor
9+
from ._bounding_boxes import CLAMPING_MODE_TYPE
910

1011

1112
class KeyPoints(TVTensor):
@@ -43,6 +44,8 @@ class KeyPoints(TVTensor):
4344
:func:`torch.as_tensor`.
4445
canvas_size (two-tuple of ints): Height and width of the corresponding
4546
image or video.
47+
clamping_mode: The clamping mode to use when applying transforms that may result in key points
48+
outside of the image. Possible values are: "soft", "hard", or ``None``. Read more in :ref:`clamping_mode_tuto`.
4649
dtype (torch.dtype, optional): Desired data type of the bounding box. If
4750
omitted, will be inferred from ``data``.
4851
device (torch.device, optional): Desired device of the bounding box. If
@@ -55,29 +58,34 @@ class KeyPoints(TVTensor):
5558
"""
5659

5760
canvas_size: tuple[int, int]
61+
clamping_mode: CLAMPING_MODE_TYPE
5862

5963
@classmethod
60-
def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], check_dims: bool = True) -> KeyPoints: # type: ignore[override]
64+
def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft", check_dims: bool = True) -> KeyPoints: # type: ignore[override]
6165
if check_dims:
6266
if tensor.ndim == 1:
6367
tensor = tensor.unsqueeze(0)
6468
elif tensor.shape[-1] != 2:
6569
raise ValueError(f"Expected a tensor of shape (..., 2), not {tensor.shape}")
70+
if clamping_mode is not None and clamping_mode not in ("hard", "soft"):
71+
raise ValueError(f"clamping_mode must be None, hard or soft, got {clamping_mode}.")
6672
points = tensor.as_subclass(cls)
6773
points.canvas_size = canvas_size
74+
points.clamping_mode = clamping_mode
6875
return points
6976

7077
def __new__(
7178
cls,
7279
data: Any,
7380
*,
7481
canvas_size: tuple[int, int],
82+
clamping_mode: CLAMPING_MODE_TYPE = "soft",
7583
dtype: torch.dtype | None = None,
7684
device: torch.device | str | int | None = None,
7785
requires_grad: bool | None = None,
7886
) -> KeyPoints:
7987
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
80-
return cls._wrap(tensor, canvas_size=canvas_size)
88+
return cls._wrap(tensor, canvas_size=canvas_size, clamping_mode=clamping_mode)
8189

8290
@classmethod
8391
def _wrap_output(
@@ -89,14 +97,14 @@ def _wrap_output(
8997
# Similar to BoundingBoxes._wrap_output(), see comment there.
9098
flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator]
9199
first_keypoints_from_args = next(x for x in flat_params if isinstance(x, KeyPoints))
92-
canvas_size = first_keypoints_from_args.canvas_size
100+
canvas_size, clamping_mode = first_keypoints_from_args.canvas_size, first_keypoints_from_args.clamping_mode
93101

94102
if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints):
95-
output = KeyPoints._wrap(output, canvas_size=canvas_size, check_dims=False)
103+
output = KeyPoints._wrap(output, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False)
96104
elif isinstance(output, (tuple, list)):
97105
# This branch exists for chunk() and unbind()
98-
output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, check_dims=False) for part in output)
106+
output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False) for part in output)
99107
return output
100108

101109
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
102-
return self._make_repr(canvas_size=self.canvas_size)
110+
return self._make_repr(canvas_size=self.canvas_size, clamping_mode=self.clamping_mode)

Comments (0)