Commit fe51080

add clamping_mode parameter to clamp_keypoints functional and class
1 parent c92bc32 commit fe51080

3 files changed: +67 -39 lines changed

torchvision/transforms/v2/_meta.py

Lines changed: 16 additions & 7 deletions
@@ -2,7 +2,7 @@

 from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
+from torchvision.tv_tensors import CLAMPING_MODE_TYPE


 class ConvertBoundingBoxFormat(Transform):
@@ -46,17 +46,26 @@ def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
 class ClampKeyPoints(Transform):
     """Clamp keypoints to their corresponding image dimensions.

-    The clamping is done according to the keypoints' ``canvas_size`` meta-data.
+    Args:
+        clamping_mode: Default is "auto" which relies on the input keypoint'
+            ``clamping_mode`` attribute.
+    The clamping is done according to the keypoints' ``canvas_size`` meta-data.
+    Read more in :ref:`clamping_mode_tuto`
+    for more details on how to use this transform.
+
     """
+    def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None:
+        super().__init__()
+        self.clamping_mode = clamping_mode

     _transformed_types = (tv_tensors.KeyPoints,)

     def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints:
-        return F.clamp_keypoints(inpt)  # type: ignore[return-value]
+        return F.clamp_keypoints(inpt, clamping_mode=self.clamping_mode)  # type: ignore[return-value]


 class SetClampingMode(Transform):
-    """Sets the ``clamping_mode`` attribute of the bounding boxes for future transforms.
+    """Sets the ``clamping_mode`` attribute of the bounding boxes and keypoints for future transforms.

@@ -73,9 +82,9 @@ def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None:
         if self.clamping_mode not in (None, "soft", "hard"):
            raise ValueError(f"clamping_mode must be soft, hard or None, got {clamping_mode}")

-    _transformed_types = (tv_tensors.BoundingBoxes,)
+    _transformed_types = (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints)

-    def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
-        out: tv_tensors.BoundingBoxes = inpt.clone()  # type: ignore[assignment]
+    def transform(self, inpt: tv_tensors.TVTensor, params: dict[str, Any]) -> tv_tensors.TVTensor:
+        out: tv_tensors.TVTensor = inpt.clone()  # type: ignore[assignment]
         out.clamping_mode = self.clamping_mode
         return out
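For orientation, a minimal usage sketch of the two updated transforms. The coordinates and canvas size below are made up for illustration, and it is assumed that tv_tensors.KeyPoints carries the clamping_mode attribute that this commit reads:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    # Hypothetical keypoints on a (H=100, W=200) canvas; the second point is out of bounds.
    kp = tv_tensors.KeyPoints(
        torch.tensor([[10.0, 20.0], [250.0, -5.0]]),
        canvas_size=(100, 200),
    )

    # "auto" (the default) is documented to defer to the input's own clamping_mode;
    # "hard" clamps coordinates into [0, canvas_size - 1].
    clamped = v2.ClampKeyPoints(clamping_mode="hard")(kp)

    # SetClampingMode now also accepts KeyPoints, so downstream transforms pick up
    # the requested mode from the tv_tensor itself.
    kp_soft = v2.SetClampingMode(clamping_mode="soft")(kp)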

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 37 additions & 26 deletions
@@ -20,7 +20,7 @@
     pil_to_tensor,
     to_pil_image,
 )
-from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
+from torchvision.tv_tensors import CLAMPING_MODE_TYPE

 from torchvision.utils import _log_api_usage_once

@@ -67,16 +67,16 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image(mask)


-def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]):
+def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft"):
     shape = keypoints.shape
     keypoints = keypoints.clone().reshape(-1, 2)
     keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_()
-    return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
+    return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size, clamping_mode=clamping_mode)


 @_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _horizontal_flip_keypoints_dispatch(keypoints: tv_tensors.KeyPoints):
-    out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size)
+    out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size, clamping_mode=keypoints.clamping_mode)
     return tv_tensors.wrap(out, like=keypoints)

@@ -155,11 +155,11 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image(mask)


-def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor:
+def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft",) -> torch.Tensor:
     shape = keypoints.shape
     keypoints = keypoints.clone().reshape(-1, 2)
     keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_()
-    return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
+    return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size, clamping_mode=clamping_mode)


 def vertical_flip_bounding_boxes(
@@ -199,7 +199,7 @@ def vertical_flip_bounding_boxes(

 @_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _vertical_flip_keypoints_dispatch(inpt: tv_tensors.KeyPoints) -> tv_tensors.KeyPoints:
-    output = vertical_flip_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size)
+    output = vertical_flip_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=inpt.clamping_mode)
     return tv_tensors.wrap(output, like=inpt)

@@ -968,6 +968,7 @@ def _affine_keypoints_with_expand(
     shear: list[float],
     center: Optional[list[float]] = None,
     expand: bool = False,
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     if keypoints.numel() == 0:
         return keypoints, canvas_size
@@ -1026,7 +1027,7 @@
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         canvas_size = (new_height, new_width)

-    out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size).reshape(original_shape)
+    out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape)
     out_keypoints = out_keypoints.to(original_dtype)

     return out_keypoints, canvas_size
@@ -1040,6 +1041,7 @@ def affine_keypoints(
     scale: float,
     shear: list[float],
     center: Optional[list[float]] = None,
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ):
     return _affine_keypoints_with_expand(
         keypoints=keypoints,
@@ -1050,6 +1052,7 @@ def affine_keypoints(
         shear=shear,
         center=center,
         expand=False,
+        clamping_mode=clamping_mode,
     )


@@ -1071,6 +1074,7 @@ def _affine_keypoints_dispatch(
         scale=scale,
         shear=shear,
         center=center,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -1393,6 +1397,7 @@ def rotate_keypoints(
     angle: float,
     expand: bool = False,
     center: Optional[list[float]] = None,
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     return _affine_keypoints_with_expand(
         keypoints=keypoints,
@@ -1403,6 +1408,7 @@ def rotate_keypoints(
         shear=[0.0, 0.0],
         center=center,
         expand=expand,
+        clamping_mode=clamping_mode,
     )


@@ -1411,7 +1417,7 @@ def _rotate_keypoints_dispatch(
     inpt: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[list[float]] = None, **kwargs
 ) -> tv_tensors.KeyPoints:
     output, canvas_size = rotate_keypoints(
-        inpt, canvas_size=inpt.canvas_size, angle=angle, center=center, expand=expand
+        inpt, canvas_size=inpt.canvas_size, angle=angle, center=center, expand=expand, clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -1683,7 +1689,7 @@ def pad_mask(


 def pad_keypoints(
-    keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant"
+    keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant", clamping_mode: CLAMPING_MODE_TYPE = "soft"
 ):
     SUPPORTED_MODES = ["constant"]
     if padding_mode not in SUPPORTED_MODES:
@@ -1695,20 +1701,21 @@ def pad_keypoints(
     left, right, top, bottom = _parse_pad_padding(padding)
     pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
     canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right)
-    return clamp_keypoints(keypoints + pad, canvas_size), canvas_size
+    return clamp_keypoints(keypoints + pad, canvas_size=canvas_size, clamping_mode=clamping_mode), canvas_size


 @_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _pad_keypoints_dispatch(
-    keypoints: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs
+    inpt: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs
 ) -> tv_tensors.KeyPoints:
     output, canvas_size = pad_keypoints(
-        keypoints.as_subclass(torch.Tensor),
-        canvas_size=keypoints.canvas_size,
+        inpt.as_subclass(torch.Tensor),
+        canvas_size=inpt.canvas_size,
         padding=padding,
         padding_mode=padding_mode,
+        clamping_mode=inpt.clamping_mode,
     )
-    return tv_tensors.wrap(output, like=keypoints, canvas_size=canvas_size)
+    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


 def pad_bounding_boxes(
@@ -1812,19 +1819,20 @@ def crop_keypoints(
     left: int,
     height: int,
     width: int,
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:

     keypoints = keypoints - torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
     canvas_size = (height, width)

-    return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size
+    return clamp_keypoints(keypoints, canvas_size=canvas_size, clamping_mode=clamping_mode), canvas_size


 @_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _crop_keypoints_dispatch(
     inpt: tv_tensors.KeyPoints, top: int, left: int, height: int, width: int
 ) -> tv_tensors.KeyPoints:
-    output, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width)
+    output, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, clamping_mode=inpt.clamping_mode)
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -2024,6 +2032,7 @@ def perspective_keypoints(
     startpoints: Optional[list[list[int]]],
     endpoints: Optional[list[list[int]]],
     coefficients: Optional[list[float]] = None,
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ):
     if keypoints.numel() == 0:
         return keypoints
@@ -2047,7 +2056,7 @@
     numer_points = torch.matmul(points, theta1.T)
     denom_points = torch.matmul(points, theta2.T)
     transformed_points = numer_points.div_(denom_points)
-    return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size).reshape(original_shape)
+    return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape)


 @_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -2064,6 +2073,7 @@ def _perspective_keypoints_dispatch(
         startpoints=startpoints,
         endpoints=endpoints,
         coefficients=coefficients,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt)

@@ -2344,7 +2354,7 @@ def _create_identity_grid(size: tuple[int, int], device: torch.device, dtype: torch.dtype):


 def elastic_keypoints(
-    keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor
+    keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor, clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> torch.Tensor:
     expected_shape = (1, canvas_size[0], canvas_size[1], 2)
     if not isinstance(displacement, torch.Tensor):
@@ -2376,12 +2386,12 @@ def elastic_keypoints(
     t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
     transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

-    return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size).reshape(original_shape)
+    return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape)


 @_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _elastic_keypoints_dispatch(inpt: tv_tensors.KeyPoints, displacement: torch.Tensor, **kwargs):
-    output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement)
+    output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement, clamping_mode=inpt.clamping_mode)
     return tv_tensors.wrap(output, like=inpt)

@@ -2578,16 +2588,16 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: list[int]) -> PIL.Image.Image:
     return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)


-def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int]):
+def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int], clamping_mode: CLAMPING_MODE_TYPE = "soft",):
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
-    return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width)
+    return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width, clamping_mode=clamping_mode)


 @_register_kernel_internal(center_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: list[int]) -> tv_tensors.KeyPoints:
     output, canvas_size = center_crop_keypoints(
-        inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size
+        inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size, clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -2745,8 +2755,9 @@ def resized_crop_keypoints(
     height: int,
     width: int,
     size: list[int],
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
-    keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width)
+    keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width, clamping_mode=clamping_mode)
     return resize_keypoints(keypoints, size=size, canvas_size=canvas_size)


@@ -2755,7 +2766,7 @@ def _resized_crop_keypoints_dispatch(
     inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs
 ):
     output, canvas_size = resized_crop_keypoints(
-        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
+        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size, clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
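Taken together, these kernel changes mean the KeyPoints dispatchers now forward the input tv_tensor's clamping_mode down to clamp_keypoints instead of always using the previous default. A rough sketch of what that looks like from the public functional API (the tensor values are illustrative, and the KeyPoints constructor arguments are an assumption):

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    kp = tv_tensors.KeyPoints(
        torch.tensor([[5.0, 5.0], [199.0, 99.0]]),
        canvas_size=(100, 200),
    )

    # The registered kernels (_horizontal_flip_keypoints_dispatch, _pad_keypoints_dispatch, ...)
    # now call the low-level kernels with clamping_mode=inpt.clamping_mode, so the output
    # keeps whatever clamping behaviour the input KeyPoints carried.
    flipped = F.horizontal_flip(kp)
    padded = F.pad(kp, padding=[2, 3])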

torchvision/transforms/v2/functional/_meta.py

Lines changed: 14 additions & 6 deletions
@@ -5,7 +5,7 @@
 from torchvision import tv_tensors
 from torchvision.transforms import _functional_pil as _FP
 from torchvision.tv_tensors import BoundingBoxFormat
-from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
+from torchvision.tv_tensors import CLAMPING_MODE_TYPE

 from torchvision.utils import _log_api_usage_once

@@ -653,7 +653,9 @@ def clamp_bounding_boxes(
     )


-def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor:
+def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE) -> torch.Tensor:
+    if clamping_mode is None or clamping_mode != "hard":
+        return keypoints.clone()
     dtype = keypoints.dtype
     keypoints = keypoints.clone() if keypoints.is_floating_point() else keypoints.float()
     # Note that max is canvas_size[i] - 1 and not can canvas_size[i] like for
@@ -666,20 +668,26 @@ def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor:
 def clamp_keypoints(
     inpt: torch.Tensor,
     canvas_size: Optional[tuple[int, int]] = None,
+    clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto",
 ) -> torch.Tensor:
     """See :func:`~torchvision.transforms.v2.ClampKeyPoints` for details."""
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_keypoints)

+    if clamping_mode is not None and clamping_mode not in ("soft", "hard", "auto"):
+        raise ValueError(f"clamping_mode must be soft, hard, auto or None, got {clamping_mode}")
+
     if torch.jit.is_scripting() or is_pure_tensor(inpt):

-        if canvas_size is None:
-            raise ValueError("For pure tensor inputs, `canvas_size` has to be passed.")
-        return _clamp_keypoints(inpt, canvas_size=canvas_size)
+        if canvas_size is None or (clamping_mode is not None and clamping_mode == "auto"):
+            raise ValueError("For pure tensor inputs, `canvas_size` and `clamping_mode` have to be passed.")
+        return _clamp_keypoints(inpt, canvas_size=canvas_size, clamping_mode=clamping_mode)
     elif isinstance(inpt, tv_tensors.KeyPoints):
         if canvas_size is not None:
             raise ValueError("For keypoints tv_tensor inputs, `canvas_size` must not be passed.")
-        output = _clamp_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size)
+        if clamping_mode is None and clamping_mode == "auto":
+            clamping_mode = inpt.clamping_mode
+        output = _clamp_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=clamping_mode)
         return tv_tensors.wrap(output, like=inpt)
     else:
         raise TypeError(f"Input can either be a plain tensor or a keypoints tv_tensor, but got {type(inpt)} instead.")
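The updated functional therefore has two call paths: for pure tensors, canvas_size must be given and clamping_mode may not be left at "auto"; for KeyPoints inputs, canvas_size must not be passed, and "auto" is documented to fall back to the input's own clamping_mode attribute. A short sketch with illustrative values:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    pts = torch.tensor([[120.0, -3.0], [10.0, 10.0]])

    # Pure tensor path: both canvas_size and an explicit clamping_mode are required.
    clamped = F.clamp_keypoints(pts, canvas_size=(100, 100), clamping_mode="hard")

    # tv_tensor path: canvas_size comes from the KeyPoints itself; passing it raises.
    kp = tv_tensors.KeyPoints(pts, canvas_size=(100, 100))
    clamped_kp = F.clamp_keypoints(kp)  # default clamping_mode="auto"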

0 commit comments