Skip to content

Commit ab78346

Browse files
Simplify _get_sanitize_keypoints_mask
1 parent 46a878a commit ab78346

File tree

3 files changed

+28
-183
lines changed

3 files changed

+28
-183
lines changed

test/test_transforms_v2.py

Lines changed: 10 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -7401,24 +7401,22 @@ class TestSanitizeKeyPoints:
74017401
def _make_keypoints_with_validity(
74027402
self,
74037403
canvas_size=(100, 100),
7404-
min_valid_edge_distance=0,
7405-
min_invalid_points=1,
74067404
shape="2d", # "2d", "3d", "4d" for different keypoint shapes
74077405
):
74087406
"""Create keypoints with known validity for testing."""
74097407
canvas_h, canvas_w = canvas_size
74107408

74117409
if shape == "2d": # [N_points, 2]
74127410
keypoints_data = [
7413-
([5, 5], min_valid_edge_distance <= 5), # Valid point inside image
7414-
([canvas_w - 6, canvas_h - 6], min_valid_edge_distance <= 5), # Valid point near corner
7411+
([5, 5], True), # Valid point inside image
7412+
([canvas_w - 6, canvas_h - 6], True), # Valid point near corner
74157413
([canvas_w // 2, canvas_h // 2], True), # Valid point in center
74167414
([-1, canvas_h // 2], False), # Invalid: x < 0
74177415
([canvas_w // 2, -1], False), # Invalid: y < 0
74187416
([canvas_w, canvas_h // 2], False), # Invalid: x >= canvas_w
74197417
([canvas_w // 2, canvas_h], False), # Invalid: y >= canvas_h
7420-
([0, 0], min_valid_edge_distance <= 0), # Edge case: exactly on edge
7421-
([canvas_w - 1, canvas_h - 1], min_valid_edge_distance <= 0), # Edge case: exactly on edge
7418+
([0, 0], True), # Edge case: exactly on edge
7419+
([canvas_w - 1, canvas_h - 1], True), # Edge case: exactly on edge
74227420
]
74237421
points, validity = zip(*keypoints_data)
74247422
keypoints = torch.tensor(points, dtype=torch.float32)
@@ -7429,11 +7427,11 @@ def _make_keypoints_with_validity(
74297427
# Group 1: All points valid
74307428
([[10, 10], [20, 20], [30, 30]], True),
74317429
# Group 2: One invalid point (the entire group is removed)
7432-
([[10, 10], [20, 20], [-5, 30]], min_invalid_points > 1),
7430+
([[10, 10], [20, 20], [-5, 30]], False),
74337431
# Group 3: All points invalid
74347432
([[-1, -1], [-2, -2], [-3, -3]], False),
74357433
# Group 4: Mix of valid and invalid (the entire group is removed)
7436-
([[10, 10], [-1, 20], [-2, 30]], min_invalid_points > 2),
7434+
([[10, 10], [-1, 20], [-2, 30]], False),
74377435
]
74387436
groups, validity = zip(*keypoints_data)
74397437
keypoints = torch.tensor(groups, dtype=torch.float32)
@@ -7444,7 +7442,7 @@ def _make_keypoints_with_validity(
74447442
# Object 1: All bones valid
74457443
([[[10, 10], [15, 15]], [[20, 20], [25, 25]]], True),
74467444
# Object 2: One bone with invalid point
7447-
([[[10, 10], [15, 15]], [[-1, 20], [25, 25]]], min_invalid_points > 1),
7445+
([[[10, 10], [15, 15]], [[-1, 20], [25, 25]]], False),
74487446
# Object 3: All bones invalid
74497447
([[[-1, -1], [-2, -2]], [[-3, -3], [-4, -4]]], False),
74507448
]
@@ -7457,26 +7455,14 @@ def _make_keypoints_with_validity(
74577455
return keypoints, validity
74587456

74597457
@pytest.mark.parametrize("shape", ["2d", "3d", "4d"])
7460-
@pytest.mark.parametrize("min_valid_edge_distance", [0, 1, 5, 6])
7461-
@pytest.mark.parametrize("min_invalid_points", [1, 2, 0.5])
74627458
@pytest.mark.parametrize("input_type", [torch.Tensor, tv_tensors.KeyPoints])
7463-
def test_functional(self, shape, min_valid_edge_distance, min_invalid_points, input_type):
7459+
def test_functional(self, shape, input_type):
74647460
"""Test the sanitize_keypoints functional interface."""
7465-
# Check for invalid configuration
7466-
if shape == "2d" and min_invalid_points > 1:
7467-
pytest.xfail("min_invalid_points > 1 does not make sense for 2D keypoints")
74687461

74697462
# Create inputs
74707463
canvas_size = (50, 50)
7471-
if isinstance(min_invalid_points, float):
7472-
num_groups = 4 if shape == "4d" else 3
7473-
min_invalid_points_int = math.ceil(min_invalid_points * num_groups)
7474-
else:
7475-
min_invalid_points_int = min_invalid_points
74767464
keypoints, expected_validity = self._make_keypoints_with_validity(
74777465
canvas_size=canvas_size,
7478-
min_valid_edge_distance=min_valid_edge_distance,
7479-
min_invalid_points=min_invalid_points_int,
74807466
shape=shape,
74817467
)
74827468

@@ -7490,8 +7476,6 @@ def test_functional(self, shape, min_valid_edge_distance, min_invalid_points, in
74907476
result_keypoints, valid_mask = F.sanitize_keypoints(
74917477
keypoints,
74927478
canvas_size=canvas_size_arg,
7493-
min_valid_edge_distance=min_valid_edge_distance,
7494-
min_invalid_points=min_invalid_points,
74957479
)
74967480

74977481
# Check return types
@@ -7522,8 +7506,6 @@ def test_kernel(self, shape):
75227506
)
75237507

75247508
@pytest.mark.parametrize("shape", ["2d", "3d", "4d"])
7525-
@pytest.mark.parametrize("min_valid_edge_distance", [0, 2])
7526-
@pytest.mark.parametrize("min_invalid_points", [1, 0.3])
75277509
@pytest.mark.parametrize(
75287510
"labels_getter",
75297511
(
@@ -7536,26 +7518,15 @@ def test_kernel(self, shape):
75367518
),
75377519
)
75387520
@pytest.mark.parametrize("sample_type", (tuple, dict))
7539-
def test_transform(self, shape, min_valid_edge_distance, min_invalid_points, labels_getter, sample_type):
7521+
def test_transform(self, shape, labels_getter, sample_type):
75407522
"""Test the SanitizeKeyPoints transform class."""
75417523
if sample_type is tuple and not isinstance(labels_getter, str):
75427524
# Lambda-based labels_getter doesn't work with tuple input
75437525
return
75447526

7545-
# Check for invalid configuration
7546-
if shape == "2d" and min_invalid_points > 1:
7547-
pytest.xfail("min_invalid_points > 1 does not make sense for 2D keypoints")
7548-
75497527
canvas_size = (40, 40)
7550-
if isinstance(min_invalid_points, float):
7551-
num_groups = 4 if shape == "4d" else 3
7552-
min_invalid_points_int = math.ceil(min_invalid_points * num_groups)
7553-
else:
7554-
min_invalid_points_int = min_invalid_points
75557528
keypoints, expected_validity = self._make_keypoints_with_validity(
75567529
canvas_size=canvas_size,
7557-
min_valid_edge_distance=min_valid_edge_distance,
7558-
min_invalid_points=min_invalid_points_int,
75597530
shape=shape,
75607531
)
75617532

@@ -7585,8 +7556,6 @@ def test_transform(self, shape, min_valid_edge_distance, min_invalid_points, lab
75857556

75867557
# Apply transform
75877558
transform = transforms.SanitizeKeyPoints(
7588-
min_valid_edge_distance=min_valid_edge_distance,
7589-
min_invalid_points=min_invalid_points,
75907559
labels_getter=labels_getter,
75917560
)
75927561
out = transform(sample)
@@ -7644,6 +7613,7 @@ def test_edge_cases(self):
76447613
# Test empty keypoints
76457614
empty_keypoints = tv_tensors.KeyPoints(torch.empty(0, 2), canvas_size=canvas_size)
76467615
result, valid_mask = F.sanitize_keypoints(empty_keypoints)
7616+
print(empty_keypoints, result, valid_mask)
76477617
assert tuple(result.shape) == (0, 2)
76487618
assert valid_mask.shape[0] == 0
76497619

@@ -7659,43 +7629,6 @@ def test_edge_cases(self):
76597629
assert tuple(result.shape) == (0, 2)
76607630
assert not valid_mask.any()
76617631

7662-
def test_min_invalid_points_fraction(self):
7663-
"""Test min_invalid_points as a fraction."""
7664-
canvas_size = (20, 20)
7665-
7666-
# Create 3D keypoints with 4 points per object
7667-
keypoints = torch.tensor(
7668-
[
7669-
# Object 1: 1 invalid point out of 4 (25% invalid)
7670-
[[5, 5], [10, 10], [15, 15], [-1, -1]],
7671-
# Object 2: 2 invalid points out of 4 (50% invalid)
7672-
[[5, 5], [10, 10], [-1, -1], [-2, -2]],
7673-
# Object 3: 3 invalid points out of 4 (75% invalid)
7674-
[[5, 5], [-1, -1], [-2, -2], [-3, -3]],
7675-
],
7676-
dtype=torch.float32,
7677-
)
7678-
7679-
keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size)
7680-
7681-
# Test with 30% threshold - should keep object 1
7682-
result, valid_mask = F.sanitize_keypoints(keypoints, min_invalid_points=0.3)
7683-
expected_valid = torch.tensor([True, False, False])
7684-
assert_equal(valid_mask, expected_valid)
7685-
assert result.shape[0] == 1
7686-
7687-
# Test with 60% threshold - should keep objects 1 and 2
7688-
result, valid_mask = F.sanitize_keypoints(keypoints, min_invalid_points=0.6)
7689-
expected_valid = torch.tensor([True, True, False])
7690-
assert_equal(valid_mask, expected_valid)
7691-
assert result.shape[0] == 2
7692-
7693-
# Test with 100% threshold - should keep all objects
7694-
result, valid_mask = F.sanitize_keypoints(keypoints, min_invalid_points=1.0)
7695-
expected_valid = torch.tensor([True, True, True])
7696-
assert_equal(valid_mask, expected_valid)
7697-
assert result.shape[0] == 3
7698-
76997632
def test_errors_functional(self):
77007633
"""Test error conditions for the functional interface."""
77017634
good_keypoints = tv_tensors.KeyPoints([[5, 5]], canvas_size=(10, 10))
@@ -7708,16 +7641,6 @@ def test_errors_functional(self):
77087641
with pytest.raises(ValueError, match="canvas_size must be None"):
77097642
F.sanitize_keypoints(good_keypoints, canvas_size=(10, 10))
77107643

7711-
# Test invalid min_invalid_points
7712-
with pytest.raises(ValueError, match="min_invalid_points must be > 0"):
7713-
F.sanitize_keypoints(good_keypoints, min_invalid_points=0)
7714-
7715-
with pytest.raises(ValueError, match="min_invalid_points must be > 0"):
7716-
F.sanitize_keypoints(good_keypoints, min_invalid_points=-1)
7717-
7718-
with pytest.raises(ValueError, match="so min_invalid_points must be 1"):
7719-
F.sanitize_keypoints(good_keypoints, min_invalid_points=2)
7720-
77217644
def test_errors_transform(self):
77227645
"""Test error conditions for the transform class."""
77237646
good_keypoints = tv_tensors.KeyPoints([[5, 5]], canvas_size=(10, 10))
@@ -7726,10 +7649,6 @@ def test_errors_transform(self):
77267649
with pytest.raises(ValueError, match="labels_getter should either be"):
77277650
transforms.SanitizeKeyPoints(labels_getter="invalid_type") # type: ignore
77287651

7729-
# Test invalid min_invalid_points
7730-
with pytest.raises(ValueError, match="min_invalid_points must be > 0"):
7731-
transforms.SanitizeKeyPoints(min_invalid_points=0)
7732-
77337652
# Test missing labels key
77347653
with pytest.raises(ValueError, match="Could not infer where the labels are"):
77357654
bad_sample = {"keypoints": good_keypoints, "BAD_KEY": torch.tensor([0])}
@@ -7745,10 +7664,6 @@ def test_errors_transform(self):
77457664
bad_sample = {"keypoints": good_keypoints, "labels": torch.tensor([0, 1, 2])}
77467665
transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)
77477666

7748-
# Test min_invalid_points > 1 for 2D keypoints
7749-
with pytest.raises(ValueError, match="so min_invalid_points must be 1"):
7750-
transforms.SanitizeKeyPoints(min_invalid_points=2)(good_keypoints)
7751-
77527667
def test_no_label(self):
77537668
"""Test transform without labels."""
77547669
img = make_image()

torchvision/transforms/v2/_misc.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -473,34 +473,19 @@ class SanitizeKeyPoints(Transform):
473473
"""Remove keypoints outside of the image area and their corresponding labels (if any).
474474
475475
This transform removes keypoints or groups of keypoints and their associated labels that
476-
have coordinates outside of their corresponding image or within ``min_valid_edge_distance`` pixels
477-
from the image edges.
476+
have coordinates outside of their corresponding image.
478477
If you would instead like to clamp such keypoints to the image edges, use
479478
:class:`~torchvision.transforms.v2.ClampKeyPoints`.
480479
481480
It is recommended to call it at the end of a pipeline, before passing the
482481
input to the models.
483482
484-
Keypoints can be passed as a set of individual keypoints of shape ``[N_points, 2]`` or as a
485-
set of objects (e.g., polygons or polygonal chains) consisting of a fixed number of keypoints
486-
of shape ``[N_objects, ..., 2]``.
483+
Keypoints can be passed as a set of individual keypoints or as a set of objects
484+
(e.g., polygons or polygonal chains) consisting of a fixed number of keypoints of shape ``[..., 2]``.
487485
When groups of keypoints are passed (i.e., an at least 3-dimensional tensor), this transform
488486
will only remove entire groups, not individual keypoints within a group.
489487
490488
Args:
491-
min_valid_edge_distance (int, optional): The minimum distance that keypoints need to be away from the closest image
492-
edge along any axis in order to be considered valid. For example, setting this to 0 will only
493-
invalidate/remove keypoints outside of the image area, while a value of 1 will also remove keypoints
494-
lying exactly on the edge.
495-
Default is 0.
496-
min_invalid_points (int or float, optional): Minimum number or fraction of invalid keypoints required
497-
for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints
498-
if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2 invalid keypoints.
499-
If a float in ``(0.0, 1.0]`` is passed, it represents a fraction of the total number of keypoints in
500-
the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid keypoints.
501-
Note that a value of `1` (integer) is very different from `1.0` (float). The former will remove groups
502-
with any invalid keypoint, while the latter will only remove groups where all keypoints are invalid.
503-
Default is 1.
504489
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
505490
(or anything else that needs to be sanitized along with the keypoints).
506491
If set to the string ``"default"``, this will try to find a "labels" key in the input (case-insensitive), if
@@ -516,19 +501,12 @@ class SanitizeKeyPoints(Transform):
516501

517502
def __init__(
518503
self,
519-
min_valid_edge_distance: int = 0,
520-
min_invalid_points: Union[int, float] = 1,
521504
labels_getter: Union[Callable[[Any], Any], str, None] = None,
522505
) -> None:
523506
super().__init__()
524-
self.min_valid_edge_distance = min_valid_edge_distance
525-
self.min_invalid_points = min_invalid_points
526507
self.labels_getter = labels_getter
527508
self._labels_getter = _parse_labels_getter(labels_getter)
528509

529-
if min_invalid_points <= 0:
530-
raise ValueError(f"min_invalid_points must be > 0. Got {min_invalid_points}.")
531-
532510
def forward(self, *inputs: Any) -> Any:
533511
inputs = inputs if len(inputs) > 1 else inputs[0]
534512

@@ -559,8 +537,6 @@ def forward(self, *inputs: Any) -> Any:
559537
valid = F._misc._get_sanitize_keypoints_mask(
560538
points,
561539
canvas_size=points.canvas_size,
562-
min_valid_edge_distance=self.min_valid_edge_distance,
563-
min_invalid_points=self.min_invalid_points,
564540
)
565541

566542
params = dict(valid=valid, labels=labels)

0 commit comments

Comments
 (0)