Commit 211acf2

Add ClampKeyPoints and corresponding test
1 parent 651d172 commit 211acf2

6 files changed: 49 additions & 7 deletions

docs/source/transforms.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -408,6 +408,7 @@ Miscellaneous
     v2.Lambda
     v2.SanitizeBoundingBoxes
     v2.ClampBoundingBoxes
+    v2.ClampKeyPoints
     v2.UniformTemporalSubsample
     v2.JPEG
@@ -421,6 +422,7 @@ Functionals
     v2.functional.erase
     v2.functional.sanitize_bounding_boxes
     v2.functional.clamp_bounding_boxes
+    v2.functional.clamp_keypoints
     v2.functional.uniform_temporal_subsample
     v2.functional.jpeg
```

test/test_transforms_v2.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -5444,6 +5444,34 @@ def test_errors(self):
     def test_transform(self):
         check_transform(transforms.ClampBoundingBoxes(), make_bounding_boxes())
 
+
+class TestClampKeyPoints:
+    @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel(self, dtype, device):
+        keypoints = make_keypoints(dtype=dtype, device=device)
+        check_kernel(
+            F.clamp_keypoints,
+            keypoints,
+            canvas_size=keypoints.canvas_size,
+        )
+
+    def test_functional(self):
+        check_functional(F.clamp_keypoints, make_keypoints())
+
+    def test_errors(self):
+        input_tv_tensor = make_keypoints()
+        input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)
+
+        with pytest.raises(ValueError, match="`canvas_size` has to be passed"):
+            F.clamp_keypoints(input_pure_tensor, canvas_size=None)
+
+        with pytest.raises(ValueError, match="`canvas_size` must not be passed"):
+            F.clamp_keypoints(input_tv_tensor, canvas_size=input_tv_tensor.canvas_size)
+
+    def test_transform(self):
+        check_transform(transforms.ClampKeyPoints(), make_keypoints())
+
 
 class TestInvert:
     @pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.float32])
```

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -41,7 +41,7 @@
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat
+from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat
 from ._misc import (
     ConvertImageDtype,
     GaussianBlur,
```

torchvision/transforms/v2/_meta.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -34,3 +34,15 @@ class ClampBoundingBoxes(Transform):
 
     def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
         return F.clamp_bounding_boxes(inpt)  # type: ignore[return-value]
+
+
+class ClampKeyPoints(Transform):
+    """Clamp keypoints to their corresponding image dimensions.
+
+    The clamping is done according to the keypoints' ``canvas_size`` meta-data.
+    """
+
+    _transformed_types = (tv_tensors.KeyPoints,)
+
+    def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints:
+        return F.clamp_keypoints(inpt)  # type: ignore[return-value]
```
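For reference, a minimal usage sketch of the new transform. It assumes the `KeyPoints` tv_tensor from this branch can be built from an `(N, 2)` coordinate tensor plus a `canvas_size` keyword, mirroring `BoundingBoxes`; the coordinate values are purely illustrative.

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# Two keypoints on a 100x100 canvas; the second one lies outside the image.
# (Constructor assumed to mirror BoundingBoxes: data tensor + canvas_size.)
keypoints = tv_tensors.KeyPoints(
    torch.tensor([[10.0, 20.0], [150.0, -5.0]]),
    canvas_size=(100, 100),  # (height, width)
)

# ClampKeyPoints reads canvas_size from the tv_tensor's metadata, so it takes no arguments.
clamped = v2.ClampKeyPoints()(keypoints)
# x is clamped to [0, width - 1] and y to [0, height - 1],
# so the out-of-bounds point (150, -5) becomes (99, 0).
```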

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -6,7 +6,7 @@
     clamp_bounding_boxes,
     clamp_keypoints,
     convert_bounding_box_format,
-    convert_bounding_boxes_to_points,
+    convert_bounding_boxes_to_points,  # TODOKP also needs docs
     get_dimensions_image,
     get_dimensions_video,
     get_dimensions,
@@ -157,7 +157,6 @@
     normalize_image,
     normalize_video,
     sanitize_bounding_boxes,
-    sanitize_keypoints,
     to_dtype,
     to_dtype_image,
     to_dtype_video,
```

torchvision/transforms/v2/functional/_meta.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -457,12 +457,13 @@ def clamp_bounding_boxes(
 def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor:
     dtype = keypoints.dtype
     keypoints = keypoints.clone() if keypoints.is_floating_point() else keypoints.float()
-    keypoints[..., 0].clamp_(min=0, max=canvas_size[1])
-    keypoints[..., 1].clamp_(min=0, max=canvas_size[0])
+    # Note that max is canvas_size[i] - 1 and not canvas_size[i] like for
+    # bounding boxes.
+    keypoints[..., 0].clamp_(min=0, max=canvas_size[1] - 1)
+    keypoints[..., 1].clamp_(min=0, max=canvas_size[0] - 1)
     return keypoints.to(dtype=dtype)
 
 
-# TODOKP there is no corresponding transform and this isn't tested
 def clamp_keypoints(
     inpt: torch.Tensor,
     canvas_size: Optional[tuple[int, int]] = None,
@@ -473,7 +474,7 @@ def clamp_keypoints(
     if torch.jit.is_scripting() or is_pure_tensor(inpt):
 
         if canvas_size is None:
-            raise ValueError("For pure tensor inputs, `canvas_size` have to be passed.")
+            raise ValueError("For pure tensor inputs, `canvas_size` has to be passed.")
         return _clamp_keypoints(inpt, canvas_size=canvas_size)
     elif isinstance(inpt, tv_tensors.KeyPoints):
         if canvas_size is not None:
```

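To make the kernel/dispatcher behaviour exercised by the new tests concrete, here is a short sketch; the tensor values are illustrative and the `KeyPoints` constructor is assumed as above.

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

pts = torch.tensor([[640.0, 480.0], [-3.0, 7.0]])

# Pure tensor input: canvas_size is required, and x/y are clamped to
# [0, width - 1] and [0, height - 1] respectively.
F.clamp_keypoints(pts, canvas_size=(480, 640))  # -> [[639., 479.], [0., 7.]]

# KeyPoints input: canvas_size comes from the metadata and must not be passed again.
kp = tv_tensors.KeyPoints(pts, canvas_size=(480, 640))
F.clamp_keypoints(kp)  # same values, returned as a KeyPoints tv_tensor
```

Presumably the `- 1` (unlike bounding boxes) reflects that a keypoint addresses a pixel location, whereas a box edge at `canvas_size` is still a valid boundary enclosing the last row or column of pixels.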