Callidior

Context

This PR proposes to add a SanitizeKeyPoints transform, similar to the existing SanitizeBoundingBoxes (#7246). The transform removes keypoints that lie outside of the valid image area. This can happen after geometric transformations with the new default clamping_mode proposed in #9234, which allows automatic clamping of keypoints to be disabled.

This implementation follows the proposal in issue #9223. It addresses the problem that the previous default of clamping keypoints to the image edges modifies their positions and misaligns them with the actual locations in the transformed image.

This PR hence only makes sense in combination with new clamping modes such as the one proposed in #9234.

Implementation details

To understand the behavior of the proposed SanitizeKeyPoints transform, we need to distinguish two cases of keypoint formats:

  • tv_tensors.KeyPoints contains a set of keypoints of shape [N_points, 2] or [N_points, 1, 2]. In this case, the transform will remove all keypoints lying outside of the valid image region.
  • tv_tensors.KeyPoints contains groups of keypoints, i.e., several objects, each consisting of a certain number of keypoints (e.g., polygons, polygonal chains, skeletons etc.). It is a tensor of shape [N_objects, N_points, 2] or, in general, [N_objects, ..., 2]. In this case, the transform will remove all objects (first dimension) that have at least a certain number of keypoints lying outside of the valid image region.
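The two cases can be illustrated with a minimal pure-Python sketch. It is not the code from this PR: the helper names are hypothetical, the valid region is assumed to be the pixel range 0..width-1 by 0..height-1, and the grouped case assumes that a single invalid keypoint is enough to drop an object.

```python
from typing import List, Tuple

Point = Tuple[int, int]  # (x, y)


def is_inside(point: Point, canvas_size: Tuple[int, int]) -> bool:
    """Assumed validity rule: 0 <= x <= width - 1 and 0 <= y <= height - 1."""
    height, width = canvas_size
    x, y = point
    return 0 <= x <= width - 1 and 0 <= y <= height - 1


def sanitize_points(points: List[Point], canvas_size: Tuple[int, int]) -> List[Point]:
    """Case 1: a flat set of keypoints; each point is kept or dropped on its own."""
    return [p for p in points if is_inside(p, canvas_size)]


def sanitize_groups(groups: List[List[Point]], canvas_size: Tuple[int, int]) -> List[List[Point]]:
    """Case 2: groups of keypoints; an object is dropped as a whole.

    Assumption: one invalid keypoint is enough to drop the object.
    """
    return [g for g in groups if all(is_inside(p, canvas_size) for p in g)]


canvas = (256, 256)  # (height, width)

# Four keypoints taken from the third crop in the example output of this PR.
flat = [(301, 27), (176, -13), (226, 107), (306, 147)]
kept_points = sanitize_points(flat, canvas)

# Two made-up objects with two keypoints each; the second one leaves the image.
groups = [[(10, 10), (20, 20)], [(10, 10), (300, 20)]]
kept_groups = sanitize_groups(groups, canvas)
```

In the flat case only (226, 107) survives, which matches the sanitized output shown for that crop. In the grouped case the second object is removed entirely, including its keypoint that was still inside the image.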

The behavior of the transform can be controlled with the following arguments:

  • min_valid_edge_distance (int): The minimum distance that keypoints need to be away from the closest image edge along any axis in order to be considered valid. For example, setting this to 0 will only invalidate/remove keypoints outside of the image area, while a value of 1 will also remove keypoints lying exactly on the edge. Default is 0.
  • min_invalid_points (int or float): Minimum number or fraction of invalid keypoints required for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2 invalid keypoints. If a float in (0.0, 1.0] is passed, it represents a fraction of the total number of keypoints in the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid keypoints. Note that a value of 1 (integer) is very different from 1.0 (float). The former will remove groups with any invalid keypoint, while the latter will only remove groups where all keypoints are invalid. Default is 1 (int).
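The interplay of the two arguments, and in particular the difference between 1 and 1.0, may be easier to see in a small sketch. This is an illustration under assumptions, not the implementation in this PR: the function names are hypothetical, and the valid region is assumed to be the pixel range 0..width-1 by 0..height-1 shrunk by the edge distance on every side.

```python
from typing import Tuple, Union


def is_valid(point: Tuple[int, int], canvas_size: Tuple[int, int],
             min_valid_edge_distance: int = 0) -> bool:
    """Assumed rule: the point must be at least `min_valid_edge_distance`
    pixels away from every image edge along both axes."""
    height, width = canvas_size
    x, y = point
    d = min_valid_edge_distance
    return d <= x <= width - 1 - d and d <= y <= height - 1 - d


def should_remove_group(n_invalid: int, n_points: int,
                        min_invalid_points: Union[int, float] = 1) -> bool:
    """Decide whether a group with `n_invalid` invalid keypoints is removed.

    An int is an absolute count; a float is a fraction of `n_points`.
    """
    if isinstance(min_invalid_points, float):
        return n_invalid >= min_invalid_points * n_points
    return n_invalid >= min_invalid_points


canvas = (256, 256)  # (height, width)

# A point exactly on the left edge: valid with distance 0, invalid with 1.
on_edge = (0, 10)
edge_results = (is_valid(on_edge, canvas, 0), is_valid(on_edge, canvas, 1))

# A group of 4 keypoints with exactly 1 invalid keypoint.
removed_with_int_1 = should_remove_group(1, 4, 1)        # any invalid point
removed_with_float_1 = should_remove_group(1, 4, 1.0)    # all points invalid
removed_with_half = should_remove_group(2, 4, 0.5)       # at least 50% invalid
```

With these inputs the integer 1 removes the group while the float 1.0 keeps it, because the latter requires all four keypoints to be invalid.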

In addition, the transform can remove the labels associated with the keypoints (or elements of any other tensors that share the first dimension with the keypoints). This is enabled by setting the labels_getter argument, which follows the same logic as the argument of the same name in SanitizeBoundingBoxes. The only difference is that the default for SanitizeKeyPoints is None, in order to avoid accidental conflicts with any bounding box labels that are also present.
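Conceptually, label handling amounts to applying the same keep/drop decision along the first dimension of both the keypoints and the labels. The sketch below shows that idea in plain Python; the function name is hypothetical, and it does not reproduce how labels_getter locates the labels in a sample.

```python
from typing import List, Sequence, Tuple


def filter_with_labels(keep: Sequence[bool],
                       keypoints: Sequence[Tuple[int, int]],
                       labels: Sequence[str]) -> Tuple[List[Tuple[int, int]], List[str]]:
    """Apply one boolean mask along the first dimension of both sequences."""
    if not (len(keep) == len(keypoints) == len(labels)):
        raise ValueError("mask, keypoints and labels must have the same length")
    kept_points = [p for p, k in zip(keypoints, keep) if k]
    kept_labels = [lab for lab, k in zip(labels, keep) if k]
    return kept_points, kept_labels


# Two keypoints from the first crop in the example of this PR: the nose ends
# up outside the 256x256 crop, one mouth keypoint stays inside.
points = [(1, -109), (6, 11)]
names = ["nose", "mouth"]
keep_mask = [False, True]

kept_points, kept_labels = filter_with_labels(keep_mask, points, names)
```

After filtering, keypoints and labels stay aligned: only (6, 11) with the label "mouth" remains.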

Illustration of the changes

The following example additionally requires PR #9234.

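from torchvision.transforms import v2
from torchvision.tv_tensors import KeyPoints

# orig_img (presumably a PIL image, given the (width, height) order of .size)
# and plot (a plotting helper) are assumed to be defined elsewhere.
orig_pts = KeyPoints(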
    [
        [[445, 700]],  # nose
        [[320, 660]],
        [[370, 660]],
        [[420, 660]],  # left eye
        [[300, 620]],
        [[420, 620]],  # left eyebrow
        [[475, 665]],
        [[515, 665]],
        [[555, 655]],  # right eye
        [[460, 625]],
        [[560, 600]],  # right eyebrow
        [[370, 780]],
        [[450, 760]],
        [[540, 780]],
        [[450, 820]],  # mouth
    ],
    canvas_size=(orig_img.size[1], orig_img.size[0]),
    clamping_mode="soft",
)
cropper = v2.RandomCrop(size=(256, 256))
crops = [cropper((orig_img, orig_pts)) for _ in range(4)]
plot([(orig_img, orig_pts)] + crops)
[Figure: sanitize-keypoints-example, showing the original image with its keypoints followed by the four random crops]

Unsanitized keypoint coordinates:

for _, pts in crops:
    print(pts)
KeyPoints([[[   1, -109]],
           [[-124, -149]],
           [[ -74, -149]],
           [[ -24, -149]],
           [[-144, -189]],
           [[ -24, -189]],
           [[  31, -144]],
           [[  71, -144]],
           [[ 111, -154]],
           [[  16, -184]],
           [[ 116, -209]],
           [[ -74,  -29]],
           [[   6,  -49]],
           [[  96,  -29]],
           [[   6,   11]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([[[ -65, -238]],
           [[-190, -278]],
           [[-140, -278]],
           [[ -90, -278]],
           [[-210, -318]],
           [[ -90, -318]],
           [[ -35, -273]],
           [[   5, -273]],
           [[  45, -283]],
           [[ -50, -313]],
           [[  50, -338]],
           [[-140, -158]],
           [[ -60, -178]],
           [[  30, -158]],
           [[ -60, -118]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([[[301,  27]],
           [[176, -13]],
           [[226, -13]],
           [[276, -13]],
           [[156, -53]],
           [[276, -53]],
           [[331,  -8]],
           [[371,  -8]],
           [[411, -18]],
           [[316, -48]],
           [[416, -73]],
           [[226, 107]],
           [[306,  87]],
           [[396, 107]],
           [[306, 147]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([[[ 372,  -27]],
           [[ 247,  -67]],
           [[ 297,  -67]],
           [[ 347,  -67]],
           [[ 227, -107]],
           [[ 347, -107]],
           [[ 402,  -62]],
           [[ 442,  -62]],
           [[ 482,  -72]],
           [[ 387, -102]],
           [[ 487, -127]],
           [[ 297,   53]],
           [[ 377,   33]],
           [[ 467,   53]],
           [[ 377,   93]]], canvas_size=(256, 256), clamping_mode=soft)

Sanitization:

sanitizer = v2.SanitizeKeyPoints()
sanitized_pts = [sanitizer(pts) for _, pts in crops]

for pts in sanitized_pts:
    print(pts)
KeyPoints([[[ 6, 11]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([], size=(0, 1, 2), dtype=torch.int64, canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([[[226, 107]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([], size=(0, 1, 2), dtype=torch.int64, canvas_size=(256, 256), clamping_mode=soft)

Testing

Please run the following unit tests:

pytest test/test_transforms_v2.py -vvv -k "SanitizeKeyPoints"
...
219 passed, 9718 deselected, 8 xfailed in 1.27s

pytorch-bot bot commented Oct 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9235

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
