Skip to content

Commit dd70d57

Browse files
Add SanitizeKeyPoints transform to remove keypoints outside of the image area (#9235)
Co-authored-by: Antoine Simoulin <[email protected]> Co-authored-by: Antoine Simoulin <[email protected]>
1 parent 39cfd8e commit dd70d57

File tree

7 files changed

+510
-2
lines changed

7 files changed

+510
-2
lines changed

docs/source/transforms.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ Miscellaneous
413413
v2.RandomErasing
414414
v2.Lambda
415415
v2.SanitizeBoundingBoxes
416+
v2.SanitizeKeyPoints
416417
v2.ClampBoundingBoxes
417418
v2.ClampKeyPoints
418419
v2.UniformTemporalSubsample
@@ -427,6 +428,7 @@ Functionals
427428
v2.functional.normalize
428429
v2.functional.erase
429430
v2.functional.sanitize_bounding_boxes
431+
v2.functional.sanitize_keypoints
430432
v2.functional.clamp_bounding_boxes
431433
v2.functional.clamp_keypoints
432434
v2.functional.uniform_temporal_subsample
@@ -530,6 +532,7 @@ Developer tools
530532
v2.query_size
531533
v2.query_chw
532534
v2.get_bounding_boxes
535+
v2.get_keypoints
533536

534537

535538
V1 API Reference

test/test_transforms_v2.py

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7397,6 +7397,326 @@ def test_errors_functional(self):
73977397
F.sanitize_bounding_boxes(good_bbox.tolist())
73987398

73997399

7400+
class TestSanitizeKeyPoints:
    """Tests for ``F.sanitize_keypoints`` and the ``transforms.SanitizeKeyPoints`` transform."""

    def _make_keypoints_with_validity(
        self,
        canvas_size=(100, 100),
        shape="2d",  # "2d", "3d", "4d" for different keypoint shapes
    ):
        """Create keypoints with known validity for testing.

        Args:
            canvas_size: ``(height, width)`` of the canvas. Only the ``"2d"`` fixture
                derives coordinates from it; the ``"3d"`` and ``"4d"`` fixtures use
                fixed coordinates.
            shape: ``"2d"`` for ``[N_points, 2]``, ``"3d"`` for ``[N_objects, N_points, 2]``
                or ``"4d"`` for ``[N_objects, N_bones, 2, 2]``.

        Returns:
            A ``(keypoints, validity)`` pair: a ``float32`` tensor and a tuple of bools
            holding one expected-validity flag per entry along the first dimension.

        Raises:
            ValueError: If ``shape`` is not one of the supported values.
        """
        canvas_h, canvas_w = canvas_size

        if shape == "2d":  # [N_points, 2]
            keypoints_data = [
                ([5, 5], True),  # Valid point inside image
                ([canvas_w - 6, canvas_h - 6], True),  # Valid point near corner
                ([canvas_w // 2, canvas_h // 2], True),  # Valid point in center
                ([-1, canvas_h // 2], False),  # Invalid: x < 0
                ([canvas_w // 2, -1], False),  # Invalid: y < 0
                ([canvas_w, canvas_h // 2], False),  # Invalid: x >= canvas_w
                ([canvas_w // 2, canvas_h], False),  # Invalid: y >= canvas_h
                ([0, 0], True),  # Edge case: exactly on edge
                ([canvas_w - 1, canvas_h - 1], True),  # Edge case: exactly on edge
            ]
            points, validity = zip(*keypoints_data)
            keypoints = torch.tensor(points, dtype=torch.float32)

        elif shape == "3d":  # [N_objects, N_points, 2]
            # Groups of keypoints with different validity patterns. A group is only
            # flagged valid when every one of its points is non-negative.
            keypoints_data = [
                # Group 1: All points valid
                ([[10, 10], [20, 20], [30, 30]], True),
                # Group 2: One invalid point -> whole group expected to be dropped
                ([[10, 10], [20, 20], [-5, 30]], False),
                # Group 3: All points invalid
                ([[-1, -1], [-2, -2], [-3, -3]], False),
                # Group 4: Mix of valid and invalid -> whole group expected to be dropped
                ([[10, 10], [-1, 20], [-2, 30]], False),
            ]
            groups, validity = zip(*keypoints_data)
            keypoints = torch.tensor(groups, dtype=torch.float32)

        elif shape == "4d":  # [N_objects, N_bones, 2, 2]
            # Bone-like structures (pairs of points)
            keypoints_data = [
                # Object 1: All bones valid
                ([[[10, 10], [15, 15]], [[20, 20], [25, 25]]], True),
                # Object 2: One bone with invalid point
                ([[[10, 10], [15, 15]], [[-1, 20], [25, 25]]], False),
                # Object 3: All bones invalid
                ([[[-1, -1], [-2, -2]], [[-3, -3], [-4, -4]]], False),
            ]
            objects, validity = zip(*keypoints_data)
            keypoints = torch.tensor(objects, dtype=torch.float32)

        else:
            raise ValueError(f"Unsupported shape: {shape}")

        return keypoints, validity

    @pytest.mark.parametrize("shape", ["2d", "3d", "4d"])
    @pytest.mark.parametrize("input_type", [torch.Tensor, tv_tensors.KeyPoints])
    def test_functional(self, shape, input_type):
        """Test the sanitize_keypoints functional interface."""

        # Create inputs
        canvas_size = (50, 50)
        keypoints, expected_validity = self._make_keypoints_with_validity(
            canvas_size=canvas_size,
            shape=shape,
        )

        if input_type is tv_tensors.KeyPoints:
            # The KeyPoints tv_tensor carries its own canvas size, so none is passed explicitly.
            keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size)
            canvas_size_arg = None
        else:
            canvas_size_arg = canvas_size

        # Apply function to be tested
        result_keypoints, valid_mask = F.sanitize_keypoints(
            keypoints,
            canvas_size=canvas_size_arg,
        )

        # Check return types
        assert isinstance(result_keypoints, input_type)
        assert isinstance(valid_mask, torch.Tensor)
        assert valid_mask.dtype == torch.bool

        # Check that valid mask matches expected validity
        assert_equal(valid_mask, torch.tensor(expected_validity))

        # Check that result has correct number of valid keypoints
        assert result_keypoints.shape[0] == valid_mask.sum().item()

        # Check that remaining keypoints shape is preserved
        assert result_keypoints.shape[1:] == keypoints.shape[1:]

    @pytest.mark.parametrize("shape", ["2d", "3d", "4d"])
    def test_kernel(self, shape):
        """Test kernel functionality."""
        canvas_size = (30, 30)
        keypoints, _ = self._make_keypoints_with_validity(canvas_size=canvas_size, shape=shape)

        check_kernel(
            F.sanitize_keypoints,
            input=keypoints,
            canvas_size=canvas_size,
            check_batched_vs_unbatched=False,  # This function doesn't support batching
        )

    @pytest.mark.parametrize("shape", ["2d", "3d", "4d"])
    @pytest.mark.parametrize(
        "labels_getter",
        (
            "default",
            lambda inputs: inputs["labels"],
            lambda inputs: (inputs["labels"], inputs["other_labels"]),
            lambda inputs: [inputs["labels"], inputs["other_labels"]],
            None,
            lambda inputs: None,
        ),
    )
    @pytest.mark.parametrize("sample_type", (tuple, dict))
    def test_transform(self, shape, labels_getter, sample_type):
        """Test the SanitizeKeyPoints transform class."""
        if sample_type is tuple and not isinstance(labels_getter, str):
            # Lambda-based labels_getter doesn't work with tuple input
            return

        canvas_size = (40, 40)
        keypoints, expected_validity = self._make_keypoints_with_validity(
            canvas_size=canvas_size,
            shape=shape,
        )

        keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size)
        num_keypoints = keypoints.shape[0]

        # Create associated labels and other data
        labels = torch.arange(num_keypoints)
        other_labels = torch.arange(num_keypoints) * 2
        masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_keypoints, *canvas_size)))
        whatever = torch.rand(10)
        input_img = torch.randint(0, 256, size=(1, 3, *canvas_size), dtype=torch.uint8)

        sample = {
            "image": input_img,
            "labels": labels,
            "keypoints": keypoints,
            "other_labels": other_labels,
            "whatever": whatever,
            "None": None,
            "masks": masks,
        }

        if sample_type is tuple:
            img = sample.pop("image")
            sample = (img, sample)

        # Apply transform
        transform = transforms.SanitizeKeyPoints(
            labels_getter=labels_getter,
        )
        out = transform(sample)

        # Extract outputs
        if sample_type is tuple:
            out_image = out[0]
            out_labels = out[1]["labels"]
            out_other_labels = out[1]["other_labels"]
            out_keypoints = out[1]["keypoints"]
            out_masks = out[1]["masks"]
            out_whatever = out[1]["whatever"]
        else:
            out_image = out["image"]
            out_labels = out["labels"]
            out_other_labels = out["other_labels"]
            out_keypoints = out["keypoints"]
            out_masks = out["masks"]
            out_whatever = out["whatever"]

        # Verify unchanged elements
        assert_equal(out_image, input_img)
        assert_equal(out_whatever, whatever)
        assert_equal(out_masks, masks)

        # Verify types
        assert isinstance(out_keypoints, tv_tensors.KeyPoints)
        assert isinstance(out_masks, tv_tensors.Mask)

        # Calculate expected valid indices
        valid_indices = [i for i, is_valid in enumerate(expected_validity) if is_valid]

        # Test label handling
        if labels_getter is None or (callable(labels_getter) and labels_getter(sample) is None):
            # Labels should be unchanged
            assert out_labels is labels
            assert out_other_labels is other_labels
        else:
            # Labels should be filtered
            assert isinstance(out_labels, torch.Tensor)
            assert out_keypoints.shape[0] == out_labels.shape[0]
            # labels were created with torch.arange, so surviving labels equal surviving indices
            assert out_labels.tolist() == valid_indices

            if callable(labels_getter) and isinstance(labels_getter(sample), (tuple, list)):
                # other_labels should also be filtered
                assert_equal(out_other_labels, out_labels * 2)  # Since other_labels = labels * 2
            else:
                # other_labels should be unchanged
                assert_equal(out_other_labels, other_labels)

    def test_edge_cases(self):
        """Test edge cases and boundary conditions."""
        canvas_size = (10, 10)

        # Test empty keypoints
        empty_keypoints = tv_tensors.KeyPoints(torch.empty(0, 2), canvas_size=canvas_size)
        result, valid_mask = F.sanitize_keypoints(empty_keypoints)
        assert tuple(result.shape) == (0, 2)
        assert valid_mask.shape[0] == 0

        # Test single valid keypoint
        single_valid = tv_tensors.KeyPoints([[5, 5]], canvas_size=canvas_size)
        result, valid_mask = F.sanitize_keypoints(single_valid)
        assert tuple(result.shape) == (1, 2)
        assert valid_mask.all()

        # Test single invalid keypoint
        single_invalid = tv_tensors.KeyPoints([[-1, -1]], canvas_size=canvas_size)
        result, valid_mask = F.sanitize_keypoints(single_invalid)
        assert tuple(result.shape) == (0, 2)
        assert not valid_mask.any()

    def test_errors_functional(self):
        """Test error conditions for the functional interface."""
        good_keypoints = tv_tensors.KeyPoints([[5, 5]], canvas_size=(10, 10))

        # Test missing canvas_size for pure tensor
        with pytest.raises(ValueError, match="canvas_size cannot be None"):
            F.sanitize_keypoints(good_keypoints.as_subclass(torch.Tensor), canvas_size=None)

        # Test canvas_size provided for tv_tensor
        with pytest.raises(ValueError, match="canvas_size must be None"):
            F.sanitize_keypoints(good_keypoints, canvas_size=(10, 10))

    def test_errors_transform(self):
        """Test error conditions for the transform class."""
        good_keypoints = tv_tensors.KeyPoints([[5, 5]], canvas_size=(10, 10))

        # Test invalid labels_getter
        with pytest.raises(ValueError, match="labels_getter should either be"):
            transforms.SanitizeKeyPoints(labels_getter="invalid_type")  # type: ignore

        # Samples are built outside the ``pytest.raises`` bodies so that only the
        # transform call itself can satisfy the expected error.

        # Test missing labels key
        bad_sample = {"keypoints": good_keypoints, "BAD_KEY": torch.tensor([0])}
        with pytest.raises(ValueError, match="Could not infer where the labels are"):
            transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)

        # Test labels not a tensor
        bad_sample = {"keypoints": good_keypoints, "labels": [0]}
        with pytest.raises(ValueError, match="must be a tensor"):
            transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)

        # Test mismatched sizes
        bad_sample = {"keypoints": good_keypoints, "labels": torch.tensor([0, 1, 2])}
        with pytest.raises(ValueError, match="Number of"):
            transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)

    def test_no_label(self):
        """Test transform without labels."""
        img = make_image()
        keypoints = make_keypoints()

        # Should raise error without labels_getter=None
        with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
            transforms.SanitizeKeyPoints(labels_getter="default")(img, keypoints)

        # Should work with labels_getter=None
        out_img, out_keypoints = transforms.SanitizeKeyPoints(labels_getter=None)(img, keypoints)
        assert isinstance(out_img, tv_tensors.Image)
        assert isinstance(out_keypoints, tv_tensors.KeyPoints)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_device_and_dtype_consistency(self, device):
        """Test that device and dtype are preserved."""
        canvas_size = (20, 20)
        keypoints = torch.tensor([[5, 5], [15, 15], [-1, -1]], dtype=torch.float32, device=device)
        keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size)

        result, valid_mask = F.sanitize_keypoints(keypoints)

        assert result.device == keypoints.device
        assert result.dtype == keypoints.dtype
        assert valid_mask.device == keypoints.device

    def test_keypoint_shapes_consistency(self):
        """Test that different keypoint shapes are handled correctly."""
        canvas_size = (50, 50)

        cases = [
            # 2D shape [N_points, 2]
            (2, [[10, 10], [20, 20], [-1, -1]]),
            # 3D shape [N_objects, N_points, 2]
            (3, [[[10, 10], [20, 20]], [[-1, -1], [30, 30]]]),
            # 4D shape [N_objects, N_bones, 2, 2]
            (4, [[[[10, 10], [20, 20]]], [[[-1, -1], [30, 30]]]]),
        ]

        for expected_ndim, data in cases:
            keypoints = tv_tensors.KeyPoints(
                torch.tensor(data, dtype=torch.float32),
                canvas_size=canvas_size,
            )
            result, valid_mask = F.sanitize_keypoints(keypoints)

            # Rank and trailing dimensions must survive sanitization
            assert result.ndim == expected_ndim
            assert result.shape[1:] == keypoints.shape[1:]
            # The returned mask must account for exactly the entries that were kept
            assert result.shape[0] == valid_mask.sum().item()
7718+
7719+
74007720
class TestJPEG:
74017721
@pytest.mark.parametrize("quality", [5, 75])
74027722
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])

torchvision/transforms/v2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@
5151
LinearTransformation,
5252
Normalize,
5353
SanitizeBoundingBoxes,
54+
SanitizeKeyPoints,
5455
ToDtype,
5556
)
5657
from ._temporal import UniformTemporalSubsample
5758
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
58-
from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
59+
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size
5960

6061
from ._deprecated import ToTensor # usort: skip

0 commit comments

Comments
 (0)