From 51a6574117749e6441cd15b2ade27c14f63e031d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Mon, 6 Oct 2025 16:02:09 +0200 Subject: [PATCH 1/6] Add SanitizeKeyPoints transform --- docs/source/transforms.rst | 3 + torchvision/transforms/v2/__init__.py | 3 +- torchvision/transforms/v2/_misc.py | 124 ++++++++++++++++- torchvision/transforms/v2/_utils.py | 12 ++ .../transforms/v2/functional/__init__.py | 1 + torchvision/transforms/v2/functional/_misc.py | 125 ++++++++++++++++++ 6 files changed, 266 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 44b4cc3aaa5..529815ead9a 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -413,6 +413,7 @@ Miscellaneous v2.RandomErasing v2.Lambda v2.SanitizeBoundingBoxes + v2.SanitizeKeyPoints v2.ClampBoundingBoxes v2.ClampKeyPoints v2.UniformTemporalSubsample @@ -427,6 +428,7 @@ Functionals v2.functional.normalize v2.functional.erase v2.functional.sanitize_bounding_boxes + v2.functional.sanitize_keypoints v2.functional.clamp_bounding_boxes v2.functional.clamp_keypoints v2.functional.uniform_temporal_subsample @@ -530,6 +532,7 @@ Developer tools v2.query_size v2.query_chw v2.get_bounding_boxes + v2.get_keypoints V1 API Reference diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 408065dab94..895bf6e2f71 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -51,10 +51,11 @@ LinearTransformation, Normalize, SanitizeBoundingBoxes, + SanitizeKeyPoints, ToDtype, ) from ._temporal import UniformTemporalSubsample from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor -from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size +from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size from ._deprecated import ToTensor # usort: skip diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 875f65d581c..aa1661c4eff 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -10,7 +10,15 @@ from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F, Transform -from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor +from ._utils import ( + _parse_labels_getter, + _setup_number_or_seq, + _setup_size, + get_bounding_boxes, + get_keypoints, + has_any, + is_pure_tensor, +) # TODO: do we want/need to expose this? @@ -459,3 +467,117 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return output else: return tv_tensors.wrap(output, like=inpt) + + +class SanitizeKeyPoints(Transform): + """Remove keypoints outside of the image area and their corresponding labels (if any). + + This transform removes keypoints or groups of keypoints and their associated labels that + have coordinates outside of their corresponding image or within ``min_valid_edge_distance`` pixels + from the image edges. + If you would instead like to clamp such keypoints to the image edges, use + :class:`~torchvision.transforms.v2.ClampKeyPoints`. + + It is recommended to call it at the end of a pipeline, before passing the + input to the models. + + Keypoints can be passed as a set of individual keypoints of shape ``[N_points, 2]`` or as a + set of objects (e.g., polygons or polygonal chains) consisting of a fixed number of keypoints + of shape ``[N_objects, ..., 2]``. + When groups of keypoints are passed (i.e., an at least 3-dimensional tensor), this transform + will only remove entire groups, not individual keypoints within a group. + + Args: + min_valid_edge_distance (int, optional): The minimum distance that keypoints need to be away from the closest image + edge along any axis in order to be considered valid. For example, setting this to 0 will only + invalidate/remove keypoints outside of the image area, while a value of 1 will also remove keypoints + lying exactly on the edge. + Default is 0. + min_invalid_points (int or float, optional): Minimum number or fraction of invalid keypoints required + for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints + if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2 invalid keypoints. + If a float in (0.0, 1.0) is passed, it represents a fraction of the total number of keypoints in + the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid keypoints. + Note that a value of `1` (integer) is very different from `1.0` (float). The former will remove groups + with any invalid keypoint, while the latter will only remove groups where all keypoints are invalid. + Default is 1. + labels_getter (callable or str or None, optional): indicates how to identify the labels in the input + (or anything else that needs to be sanitized along with the keypoints). + By default, this will try to find a "labels" key in the input (case-insensitive), if + the input is a dict or it is a tuple whose second element is a dict. + + It can also be a callable that takes the same input as the transform, and returns either: + + - A single tensor (the labels) + - A tuple/list of tensors, each of which will be subject to the same sanitization as the keypoints. + + If ``labels_getter`` is None then only keypoints are sanitized. + """ + + def __init__( + self, + min_valid_edge_distance: int = 0, + min_invalid_points: int | float = 1, + labels_getter: Union[Callable[[Any], Any], str, None] = "default", + ) -> None: + super().__init__() + self.min_valid_edge_distance = min_valid_edge_distance + self.min_invalid_points = min_invalid_points + self.labels_getter = labels_getter + self._labels_getter = _parse_labels_getter(labels_getter) + + if min_invalid_points <= 0: + raise ValueError(f"min_invalid_points must be > 0. Got {min_invalid_points}.") + + def forward(self, *inputs: Any) -> Any: + inputs = inputs if len(inputs) > 1 else inputs[0] + + labels = self._labels_getter(inputs) + if labels is not None: + msg = "The labels in the input to forward() must be a tensor or None, got {type} instead." + if isinstance(labels, torch.Tensor): + labels = (labels,) + elif isinstance(labels, (tuple, list)): + for entry in labels: + if not isinstance(entry, torch.Tensor): + # TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask] + raise ValueError(msg.format(type=type(entry))) + else: + raise ValueError(msg.format(type=type(labels))) + + flat_inputs, spec = tree_flatten(inputs) + points = get_keypoints(flat_inputs) + + if labels is not None: + for label in labels: + if points.shape[0] != label.shape[0]: + raise ValueError( + f"Number of kepyoints (shape={points.shape}) must match the number of labels." + f"Found labels with shape={label.shape})." + ) + + valid = F._misc._get_sanitize_keypoints_mask( + points, + canvas_size=points.canvas_size, + min_valid_edge_distance=self.min_valid_edge_distance, + min_invalid_points=self.min_invalid_points, + ) + + params = dict(valid=valid, labels=labels) + flat_outputs = [self.transform(inpt, params) for inpt in flat_inputs] + + return tree_unflatten(flat_outputs, spec) + + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: + is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) + is_keypoints = isinstance(inpt, tv_tensors.KeyPoints) + + if not (is_label or is_keypoints): + return inpt + + output = inpt[params["valid"]] + + if is_label: + return output + else: + return tv_tensors.wrap(output, like=inpt) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 5ed871d0554..bb6051b4e61 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -165,6 +165,18 @@ def get_bounding_boxes(flat_inputs: list[Any]) -> tv_tensors.BoundingBoxes: raise ValueError("No bounding boxes were found in the sample") +def get_keypoints(flat_inputs: list[Any]) -> tv_tensors.KeyPoints: + """Return the keypoints in the input. + + Assumes only one ``KeyPoints`` object is present. + """ + # This assumes there is only one keypoint per sample as per the general convention + try: + return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.KeyPoints)) + except StopIteration: + raise ValueError("No keypoints were found in the sample") + + def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: """Return Channel, Height, and Width.""" chws = { diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 96767d30c99..13fbaa588fe 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -156,6 +156,7 @@ normalize_image, normalize_video, sanitize_bounding_boxes, + sanitize_keypoints, to_dtype, to_dtype_image, to_dtype_video, diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 7987b034ae8..8bac75f6ac2 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -442,3 +442,128 @@ def _get_sanitize_bounding_boxes_mask( valid &= (bounding_boxes[..., 4] <= image_w) & (bounding_boxes[..., 5] <= image_h) valid &= (bounding_boxes[..., 6] <= image_w) & (bounding_boxes[..., 7] <= image_h) return valid + + +def sanitize_keypoints( + key_points: torch.Tensor, + canvas_size: Optional[tuple[int, int]] = None, + min_valid_edge_distance: int = 0, + min_invalid_points: int | float = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """Remove keypoints outside of the image area and their corresponding labels (if any). + + This transform removes keypoints or groups of keypoints and their associated labels that + have coordinates outside of their corresponding image or within ``min_valid_edge_distance`` pixels + from the image edges. + If you would instead like to clamp such keypoints to the image edges, use + :class:`~torchvision.transforms.v2.ClampKeyPoints`. + + It is recommended to call it at the end of a pipeline, before passing the + input to the models. + + Keypoints can be passed as a set of individual keypoints of shape ``[N_points, 2]`` or as a + set of objects (e.g., polygons or polygonal chains) consisting of a fixed number of keypoints + of shape ``[N_objects, ..., 2]``. + When groups of keypoints are passed (i.e., an at least 3-dimensional tensor), this transform + will only remove entire groups, not individual keypoints within a group. + + Args: + key_points (Tensor or :class:`~torchvision.tv_tensors.KeyPoints`): The keypoints to be sanitized. + canvas_size (tuple of int, optional): The canvas_size of the keypoints + (size of the corresponding image/video). + Must be left to none if ``key_points`` is a :class:`~torchvision.tv_tensors.KeyPoints` object. + min_valid_edge_distance (int, optional): The minimum distance that keypoints need to be away from the closest image + edge along any axis in order to be considered valid. For example, setting this to 0 will only + invalidate/remove keypoints outside of the image area, while a value of 1 will also remove keypoints + lying exactly on the edge. + Default is 0. + min_invalid_points (int or float, optional): Minimum number or fraction of invalid keypoints required + for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints + if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2 invalid keypoints. + If a float in (0.0, 1.0) is passed, it represents a fraction of the total number of keypoints in + the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid keypoints. + Note that a value of `1` (integer) is very different from `1.0` (float). The former will remove groups + with any invalid keypoint, while the latter will only remove groups where all keypoints are invalid. + Default is 1. + + Returns: + out (tuple of Tensors): The subset of valid keypoints, and the corresponding indexing mask. + The mask can then be used to subset other tensors (e.g. labels) that are associated with the keypoints. + """ + if torch.jit.is_scripting() or is_pure_tensor(key_points): + if canvas_size is None: + raise ValueError( + "canvas_size cannot be None if key_points is a pure tensor. " + "Set it to an appropriate value or pass key_points as a tv_tensors.KeyPoints object." + ) + valid = _get_sanitize_keypoints_mask( + key_points, + canvas_size=canvas_size, + min_valid_edge_distance=min_valid_edge_distance, + min_invalid_points=min_invalid_points, + ) + key_points = key_points[valid] + else: + if not isinstance(key_points, tv_tensors.KeyPoints): + raise ValueError("key_points must be a tv_tensors.KeyPoints instance or a pure tensor.") + if canvas_size is not None: + raise ValueError( + "canvas_size must be None when key_points is a tv_tensors.KeyPoints instance. " + f"Got canvas_size={canvas_size}. " + "Leave it to None or pass key_points as a pure tensor." + ) + valid = _get_sanitize_keypoints_mask( + key_points, + canvas_size=key_points.canvas_size, + min_valid_edge_distance=min_valid_edge_distance, + min_invalid_points=min_invalid_points, + ) + key_points = tv_tensors.wrap(key_points[valid], like=key_points) + + return key_points, valid + + +def _get_sanitize_keypoints_mask( + key_points: torch.Tensor, + canvas_size: tuple[int, int], + min_valid_edge_distance: int = 0, + min_invalid_points: int | float = 1, +) -> torch.Tensor: + + image_h, image_w = canvas_size + + # Bring keypoints tensor into canonical shape [N_instances, N_points, 2] + if key_points.ndim == 2: + key_points = key_points.unsqueeze(dim=1) + elif key_points.ndim > 3: + key_points = key_points.flatten(start_dim=1, end_dim=-2) + + # Convert min_invalid_points from relative to absolute number of points + if min_invalid_points <= 0: + raise ValueError(f"min_invalid_points must be > 0. Got {min_invalid_points}.") + if isinstance(min_invalid_points, float): + min_invalid_points = math.ceil(min_invalid_points * key_points.shape[1]) + if min_invalid_points > 1 and key_points.shape[1] == 1: + raise ValueError( + f"min_invalid_points was set to {min_invalid_points}, but key_points only contains a single point per " + "instance, so min_invalid_points must be 1." + ) + + # Compute distance of each point to the closest image edge + dists = torch.stack( + [ + key_points[..., 0], # x + image_w - 1 - key_points[..., 0], # image_w - x + key_points[..., 1], # y + image_h - 1 - key_points[..., 1], # image_h - y + ], + dim=-1, + ) + dists = dists.min(dim=-1).values # [N_instances, N_points] + + # Determine invalid points + invalid_points = dists < min_valid_edge_distance # [N_instances, N_points] + + # Determine valid instances + valid = invalid_points.sum(dim=-1) < min_invalid_points # [N_instances] + return valid From 125ab397425d739d10d16b469b3dc8b8e107f6c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Mon, 6 Oct 2025 17:55:10 +0200 Subject: [PATCH 2/6] Unit tests for sanitize_keypoints --- test/test_transforms_v2.py | 406 +++++++++++++++++++++++++++++++++++++ 1 file changed, 406 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f92f2a0bc67..e419820bd18 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -7404,6 +7404,412 @@ def test_errors_functional(self): F.sanitize_bounding_boxes(good_bbox.tolist()) +class TestSanitizeKeyPoints: + def _make_keypoints_with_validity( + self, + canvas_size=(100, 100), + min_valid_edge_distance=0, + min_invalid_points=1, + shape="2d", # "2d", "3d", "4d" for different keypoint shapes + ): + """Create keypoints with known validity for testing.""" + canvas_h, canvas_w = canvas_size + + if shape == "2d": # [N_points, 2] + keypoints_data = [ + ([5, 5], min_valid_edge_distance <= 5), # Valid point inside image + ([canvas_w - 6, canvas_h - 6], min_valid_edge_distance <= 5), # Valid point near corner + ([canvas_w // 2, canvas_h // 2], True), # Valid point in center + ([-1, canvas_h // 2], False), # Invalid: x < 0 + ([canvas_w // 2, -1], False), # Invalid: y < 0 + ([canvas_w, canvas_h // 2], False), # Invalid: x >= canvas_w + ([canvas_w // 2, canvas_h], False), # Invalid: y >= canvas_h + ([0, 0], min_valid_edge_distance <= 0), # Edge case: exactly on edge + ([canvas_w - 1, canvas_h - 1], min_valid_edge_distance <= 0), # Edge case: exactly on edge + ] + points, validity = zip(*keypoints_data) + keypoints = torch.tensor(points, dtype=torch.float32) + + elif shape == "3d": # [N_objects, N_points, 2] + # Create groups of keypoints with different validity patterns + keypoints_data = [ + # Group 1: All points valid + ([[10, 10], [20, 20], [30, 30]], True), + # Group 2: One invalid point (should be removed if min_invalid_points=1) + ([[10, 10], [20, 20], [-5, 30]], min_invalid_points > 1), + # Group 3: All points invalid + ([[-1, -1], [-2, -2], [-3, -3]], False), + # Group 4: Mix of valid and invalid (depends on min_invalid_points) + ([[10, 10], [-1, 20], [-2, 30]], min_invalid_points > 2), + ] + groups, validity = zip(*keypoints_data) + keypoints = torch.tensor(groups, dtype=torch.float32) + + elif shape == "4d": # [N_objects, N_bones, 2, 2] + # Create bone-like structures (pairs of points) + keypoints_data = [ + # Object 1: All bones valid + ([[[10, 10], [15, 15]], [[20, 20], [25, 25]]], True), + # Object 2: One bone with invalid point + ([[[10, 10], [15, 15]], [[-1, 20], [25, 25]]], min_invalid_points > 1), + # Object 3: All bones invalid + ([[[-1, -1], [-2, -2]], [[-3, -3], [-4, -4]]], False), + ] + objects, validity = zip(*keypoints_data) + keypoints = torch.tensor(objects, dtype=torch.float32) + + else: + raise ValueError(f"Unsupported shape: {shape}") + + return keypoints, validity + + @pytest.mark.parametrize("shape", ["2d", "3d", "4d"]) + @pytest.mark.parametrize("min_valid_edge_distance", [0, 1, 5, 6]) + @pytest.mark.parametrize("min_invalid_points", [1, 2, 0.5]) + @pytest.mark.parametrize("input_type", [torch.Tensor, tv_tensors.KeyPoints]) + def test_functional(self, shape, min_valid_edge_distance, min_invalid_points, input_type): + """Test the sanitize_keypoints functional interface.""" + # Check for invalid configuration + if shape == "2d" and min_invalid_points > 1: + pytest.xfail("min_invalid_points > 1 does not make sense for 2D keypoints") + + # Create inputs + canvas_size = (50, 50) + if isinstance(min_invalid_points, float): + num_groups = 4 if shape == "4d" else 3 + min_invalid_points_int = math.ceil(min_invalid_points * num_groups) + else: + min_invalid_points_int = min_invalid_points + keypoints, expected_validity = self._make_keypoints_with_validity( + canvas_size=canvas_size, + min_valid_edge_distance=min_valid_edge_distance, + min_invalid_points=min_invalid_points_int, + shape=shape, + ) + + if input_type is tv_tensors.KeyPoints: + keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size) + canvas_size_arg = None + else: + canvas_size_arg = canvas_size + + # Apply function to be tested + result_keypoints, valid_mask = F.sanitize_keypoints( + keypoints, + canvas_size=canvas_size_arg, + min_valid_edge_distance=min_valid_edge_distance, + min_invalid_points=min_invalid_points, + ) + + # Check return types + assert isinstance(result_keypoints, input_type) + assert isinstance(valid_mask, torch.Tensor) + assert valid_mask.dtype == torch.bool + + # Check that valid mask matches expected validity + assert_equal(valid_mask, torch.tensor(expected_validity)) + + # Check that result has correct number of valid keypoints + assert result_keypoints.shape[0] == valid_mask.sum().item() + + # Check that remaining keypoints shape is preserved + assert result_keypoints.shape[1:] == keypoints.shape[1:] + + @pytest.mark.parametrize("shape", ["2d", "3d", "4d"]) + def test_kernel(self, shape): + """Test kernel functionality.""" + canvas_size = (30, 30) + keypoints, _ = self._make_keypoints_with_validity(canvas_size=canvas_size, shape=shape) + + check_kernel( + F.sanitize_keypoints, + input=keypoints, + canvas_size=canvas_size, + check_batched_vs_unbatched=False, # This function doesn't support batching + ) + + @pytest.mark.parametrize("shape", ["2d", "3d", "4d"]) + @pytest.mark.parametrize("min_valid_edge_distance", [0, 2]) + @pytest.mark.parametrize("min_invalid_points", [1, 0.3]) + @pytest.mark.parametrize( + "labels_getter", + ( + "default", + lambda inputs: inputs["labels"], + lambda inputs: (inputs["labels"], inputs["other_labels"]), + lambda inputs: [inputs["labels"], inputs["other_labels"]], + None, + lambda inputs: None, + ), + ) + @pytest.mark.parametrize("sample_type", (tuple, dict)) + def test_transform(self, shape, min_valid_edge_distance, min_invalid_points, labels_getter, sample_type): + """Test the SanitizeKeyPoints transform class.""" + if sample_type is tuple and not isinstance(labels_getter, str): + # Lambda-based labels_getter doesn't work with tuple input + return + + # Check for invalid configuration + if shape == "2d" and min_invalid_points > 1: + pytest.xfail("min_invalid_points > 1 does not make sense for 2D keypoints") + + canvas_size = (40, 40) + if isinstance(min_invalid_points, float): + num_groups = 4 if shape == "4d" else 3 + min_invalid_points_int = math.ceil(min_invalid_points * num_groups) + else: + min_invalid_points_int = min_invalid_points + keypoints, expected_validity = self._make_keypoints_with_validity( + canvas_size=canvas_size, + min_valid_edge_distance=min_valid_edge_distance, + min_invalid_points=min_invalid_points_int, + shape=shape, + ) + + keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size) + num_keypoints = keypoints.shape[0] + + # Create associated labels and other data + labels = torch.arange(num_keypoints) + other_labels = torch.arange(num_keypoints) * 2 + masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_keypoints, *canvas_size))) + whatever = torch.rand(10) + input_img = torch.randint(0, 256, size=(1, 3, *canvas_size), dtype=torch.uint8) + + sample = { + "image": input_img, + "labels": labels, + "keypoints": keypoints, + "other_labels": other_labels, + "whatever": whatever, + "None": None, + "masks": masks, + } + + if sample_type is tuple: + img = sample.pop("image") + sample = (img, sample) + + # Apply transform + transform = transforms.SanitizeKeyPoints( + min_valid_edge_distance=min_valid_edge_distance, + min_invalid_points=min_invalid_points, + labels_getter=labels_getter, + ) + out = transform(sample) + + # Extract outputs + if sample_type is tuple: + out_image = out[0] + out_labels = out[1]["labels"] + out_other_labels = out[1]["other_labels"] + out_keypoints = out[1]["keypoints"] + out_masks = out[1]["masks"] + out_whatever = out[1]["whatever"] + else: + out_image = out["image"] + out_labels = out["labels"] + out_other_labels = out["other_labels"] + out_keypoints = out["keypoints"] + out_masks = out["masks"] + out_whatever = out["whatever"] + + # Verify unchanged elements + assert_equal(out_image, input_img) + assert_equal(out_whatever, whatever) + assert_equal(out_masks, masks) + + # Verify types + assert isinstance(out_keypoints, tv_tensors.KeyPoints) + assert isinstance(out_masks, tv_tensors.Mask) + + # Calculate expected valid indices + valid_indices = [i for i, is_valid in enumerate(expected_validity) if is_valid] + + # Test label handling + if labels_getter is None or (callable(labels_getter) and labels_getter(sample) is None): + # Labels should be unchanged + assert out_labels is labels + assert out_other_labels is other_labels + else: + # Labels should be filtered + assert isinstance(out_labels, torch.Tensor) + assert out_keypoints.shape[0] == out_labels.shape[0] + assert out_labels.tolist() == valid_indices + + if callable(labels_getter) and isinstance(labels_getter(sample), (tuple, list)): + # other_labels should also be filtered + assert_equal(out_other_labels, out_labels * 2) # Since other_labels = labels * 2 + else: + # other_labels and masks should be unchanged + assert_equal(out_other_labels, other_labels) + + def test_edge_cases(self): + """Test edge cases and boundary conditions.""" + canvas_size = (10, 10) + + # Test empty keypoints + empty_keypoints = tv_tensors.KeyPoints(torch.empty(0, 2), canvas_size=canvas_size) + result, valid_mask = F.sanitize_keypoints(empty_keypoints) + assert tuple(result.shape) == (0, 2) + assert valid_mask.shape[0] == 0 + + # Test single valid keypoint + single_valid = tv_tensors.KeyPoints([[5, 5]], canvas_size=canvas_size) + result, valid_mask = F.sanitize_keypoints(single_valid) + assert tuple(result.shape) == (1, 2) + assert valid_mask.all() + + # Test single invalid keypoint + single_invalid = tv_tensors.KeyPoints([[-1, -1]], canvas_size=canvas_size) + result, valid_mask = F.sanitize_keypoints(single_invalid) + assert tuple(result.shape) == (0, 2) + assert not valid_mask.any() + + def test_min_invalid_points_fraction(self): + """Test min_invalid_points as a fraction.""" + canvas_size = (20, 20) + + # Create 3D keypoints with 4 points per object + keypoints = torch.tensor( + [ + # Object 1: 1 invalid point out of 4 (25% invalid) + [[5, 5], [10, 10], [15, 15], [-1, -1]], + # Object 2: 2 invalid points out of 4 (50% invalid) + [[5, 5], [10, 10], [-1, -1], [-2, -2]], + # Object 3: 3 invalid points out of 4 (75% invalid) + [[5, 5], [-1, -1], [-2, -2], [-3, -3]], + ], + dtype=torch.float32, + ) + + keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size) + + # Test with 30% threshold - should keep object 1 + result, valid_mask = F.sanitize_keypoints(keypoints, min_invalid_points=0.3) + expected_valid = torch.tensor([True, False, False]) + assert_equal(valid_mask, expected_valid) + assert result.shape[0] == 1 + + # Test with 60% threshold - should keep objects 1 and 2 + result, valid_mask = F.sanitize_keypoints(keypoints, min_invalid_points=0.6) + expected_valid = torch.tensor([True, True, False]) + assert_equal(valid_mask, expected_valid) + assert result.shape[0] == 2 + + # Test with 100% threshold - should keep all objects + result, valid_mask = F.sanitize_keypoints(keypoints, min_invalid_points=1.0) + expected_valid = torch.tensor([True, True, True]) + assert_equal(valid_mask, expected_valid) + assert result.shape[0] == 3 + + def test_errors_functional(self): + """Test error conditions for the functional interface.""" + good_keypoints = tv_tensors.KeyPoints([[5, 5]], canvas_size=(10, 10)) + + # Test missing canvas_size for pure tensor + with pytest.raises(ValueError, match="canvas_size cannot be None"): + F.sanitize_keypoints(good_keypoints.as_subclass(torch.Tensor), canvas_size=None) + + # Test canvas_size provided for tv_tensor + with pytest.raises(ValueError, match="canvas_size must be None"): + F.sanitize_keypoints(good_keypoints, canvas_size=(10, 10)) + + # Test invalid min_invalid_points + with pytest.raises(ValueError, match="min_invalid_points must be > 0"): + F.sanitize_keypoints(good_keypoints, min_invalid_points=0) + + with pytest.raises(ValueError, match="min_invalid_points must be > 0"): + F.sanitize_keypoints(good_keypoints, min_invalid_points=-1) + + with pytest.raises(ValueError, match="so min_invalid_points must be 1"): + F.sanitize_keypoints(good_keypoints, min_invalid_points=2) + + def test_errors_transform(self): + """Test error conditions for the transform class.""" + good_keypoints = tv_tensors.KeyPoints([[5, 5]], canvas_size=(10, 10)) + + # Test invalid labels_getter + with pytest.raises(ValueError, match="labels_getter should either be"): + transforms.SanitizeKeyPoints(labels_getter="invalid_type") # type: ignore + + # Test invalid min_invalid_points + with pytest.raises(ValueError, match="min_invalid_points must be > 0"): + transforms.SanitizeKeyPoints(min_invalid_points=0) + + # Test missing labels key + with pytest.raises(ValueError, match="Could not infer where the labels are"): + bad_sample = {"keypoints": good_keypoints, "BAD_KEY": torch.tensor([0])} + transforms.SanitizeKeyPoints()(bad_sample) + + # Test labels not a tensor + with pytest.raises(ValueError, match="must be a tensor"): + bad_sample = {"keypoints": good_keypoints, "labels": [0]} + transforms.SanitizeKeyPoints()(bad_sample) + + # Test mismatched sizes + with pytest.raises(ValueError, match="Number of"): + bad_sample = {"keypoints": good_keypoints, "labels": torch.tensor([0, 1, 2])} + transforms.SanitizeKeyPoints()(bad_sample) + + # Test min_invalid_points > 1 for 2D keypoints + with pytest.raises(ValueError, match="so min_invalid_points must be 1"): + sample = {"keypoints": good_keypoints, "labels": torch.tensor([0])} + transforms.SanitizeKeyPoints(min_invalid_points=2)(sample) + + def test_no_label(self): + """Test transform without labels.""" + img = make_image() + keypoints = make_keypoints() + + # Should raise error without labels_getter=None + with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"): + transforms.SanitizeKeyPoints()(img, keypoints) + + # Should work with labels_getter=None + out_img, out_keypoints = transforms.SanitizeKeyPoints(labels_getter=None)(img, keypoints) + assert isinstance(out_img, tv_tensors.Image) + assert isinstance(out_keypoints, tv_tensors.KeyPoints) + + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_device_and_dtype_consistency(self, device): + """Test that device and dtype are preserved.""" + canvas_size = (20, 20) + keypoints = torch.tensor([[5, 5], [15, 15], [-1, -1]], dtype=torch.float32, device=device) + keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size) + + result, valid_mask = F.sanitize_keypoints(keypoints) + + assert result.device == keypoints.device + assert result.dtype == keypoints.dtype + assert valid_mask.device == keypoints.device + + def test_keypoint_shapes_consistency(self): + """Test that different keypoint shapes are handled correctly.""" + canvas_size = (50, 50) + + # Test 2D shape [N_points, 2] + kp_2d = torch.tensor([[10, 10], [20, 20], [-1, -1]], dtype=torch.float32) + kp_2d = tv_tensors.KeyPoints(kp_2d, canvas_size=canvas_size) + result_2d, valid_2d = F.sanitize_keypoints(kp_2d) + assert result_2d.ndim == 2 + assert result_2d.shape[1:] == kp_2d.shape[1:] + + # Test 3D shape [N_objects, N_points, 2] + kp_3d = torch.tensor([[[10, 10], [20, 20]], [[-1, -1], [30, 30]]], dtype=torch.float32) + kp_3d = tv_tensors.KeyPoints(kp_3d, canvas_size=canvas_size) + result_3d, valid_3d = F.sanitize_keypoints(kp_3d) + assert result_3d.ndim == 3 + assert result_3d.shape[1:] == kp_3d.shape[1:] + + # Test 4D shape [N_objects, N_bones, 2, 2] + kp_4d = torch.tensor([[[[10, 10], [20, 20]]], [[[-1, -1], [30, 30]]]], dtype=torch.float32) + kp_4d = tv_tensors.KeyPoints(kp_4d, canvas_size=canvas_size) + result_4d, valid_4d = F.sanitize_keypoints(kp_4d) + assert result_4d.ndim == 4 + assert result_4d.shape[1:] == kp_4d.shape[1:] + + class TestJPEG: @pytest.mark.parametrize("quality", [5, 75]) @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) From b820aad88294f4b9cd44b0d9384599f3b6664e10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Mon, 6 Oct 2025 18:13:42 +0200 Subject: [PATCH 3/6] Change default `labels_getter` of `SanitizeKeypoints` to `None`. --- test/test_transforms_v2.py | 11 +++++------ torchvision/transforms/v2/_misc.py | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index e419820bd18..475c6837cfb 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -7740,22 +7740,21 @@ def test_errors_transform(self): # Test missing labels key with pytest.raises(ValueError, match="Could not infer where the labels are"): bad_sample = {"keypoints": good_keypoints, "BAD_KEY": torch.tensor([0])} - transforms.SanitizeKeyPoints()(bad_sample) + transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample) # Test labels not a tensor with pytest.raises(ValueError, match="must be a tensor"): bad_sample = {"keypoints": good_keypoints, "labels": [0]} - transforms.SanitizeKeyPoints()(bad_sample) + transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample) # Test mismatched sizes with pytest.raises(ValueError, match="Number of"): bad_sample = {"keypoints": good_keypoints, "labels": torch.tensor([0, 1, 2])} - transforms.SanitizeKeyPoints()(bad_sample) + transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample) # Test min_invalid_points > 1 for 2D keypoints with pytest.raises(ValueError, match="so min_invalid_points must be 1"): - sample = {"keypoints": good_keypoints, "labels": torch.tensor([0])} - transforms.SanitizeKeyPoints(min_invalid_points=2)(sample) + transforms.SanitizeKeyPoints(min_invalid_points=2)(good_keypoints) def test_no_label(self): """Test transform without labels.""" @@ -7764,7 +7763,7 @@ def test_no_label(self): # Should raise error without labels_getter=None with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"): - transforms.SanitizeKeyPoints()(img, keypoints) + transforms.SanitizeKeyPoints(labels_getter="default")(img, keypoints) # Should work with labels_getter=None out_img, out_keypoints = transforms.SanitizeKeyPoints(labels_getter=None)(img, keypoints) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index aa1661c4eff..e1c832f1eb7 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -503,7 +503,7 @@ class SanitizeKeyPoints(Transform): Default is 1. labels_getter (callable or str or None, optional): indicates how to identify the labels in the input (or anything else that needs to be sanitized along with the keypoints). - By default, this will try to find a "labels" key in the input (case-insensitive), if + If set to the string ``"default"``, this will try to find a "labels" key in the input (case-insensitive), if the input is a dict or it is a tuple whose second element is a dict. It can also be a callable that takes the same input as the transform, and returns either: @@ -511,14 +511,14 @@ class SanitizeKeyPoints(Transform): - A single tensor (the labels) - A tuple/list of tensors, each of which will be subject to the same sanitization as the keypoints. - If ``labels_getter`` is None then only keypoints are sanitized. + If ``labels_getter`` is None (the default), then only keypoints are sanitized. """ def __init__( self, min_valid_edge_distance: int = 0, min_invalid_points: int | float = 1, - labels_getter: Union[Callable[[Any], Any], str, None] = "default", + labels_getter: Union[Callable[[Any], Any], str, None] = None, ) -> None: super().__init__() self.min_valid_edge_distance = min_valid_edge_distance From d7e84920d77ff51b517f863c3284a9cebf5efd40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Tue, 7 Oct 2025 09:53:19 +0200 Subject: [PATCH 4/6] Minor docstring fix --- torchvision/transforms/v2/_misc.py | 2 +- torchvision/transforms/v2/functional/_misc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index e1c832f1eb7..1c0df9792e5 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -496,7 +496,7 @@ class SanitizeKeyPoints(Transform): min_invalid_points (int or float, optional): Minimum number or fraction of invalid keypoints required for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2 invalid keypoints. - If a float in (0.0, 1.0) is passed, it represents a fraction of the total number of keypoints in + If a float in ``(0.0, 1.0]`` is passed, it represents a fraction of the total number of keypoints in the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid keypoints. Note that a value of `1` (integer) is very different from `1.0` (float). The former will remove groups with any invalid keypoint, while the latter will only remove groups where all keypoints are invalid. diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 8bac75f6ac2..ee36fc67745 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -480,7 +480,7 @@ def sanitize_keypoints( min_invalid_points (int or float, optional): Minimum number or fraction of invalid keypoints required for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2 invalid keypoints. - If a float in (0.0, 1.0) is passed, it represents a fraction of the total number of keypoints in + If a float in ``(0.0, 1.0]`` is passed, it represents a fraction of the total number of keypoints in the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid keypoints. Note that a value of `1` (integer) is very different from `1.0` (float). The former will remove groups with any invalid keypoint, while the latter will only remove groups where all keypoints are invalid. From e24a03f1dca763768659f7cdc26cbc89f52d4eed Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Tue, 7 Oct 2025 11:40:05 -0700 Subject: [PATCH 5/6] fix type error Summary: Fixing error `TypeError: unsupported operand type(s) for |: 'type' and 'type'` --- torchvision/transforms/v2/functional/_misc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index ee36fc67745..e320abb246c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, Union import PIL.Image import torch @@ -448,7 +448,7 @@ def sanitize_keypoints( key_points: torch.Tensor, canvas_size: Optional[tuple[int, int]] = None, min_valid_edge_distance: int = 0, - min_invalid_points: int | float = 1, + min_invalid_points: Union[int, float] = 1, ) -> tuple[torch.Tensor, torch.Tensor]: """Remove keypoints outside of the image area and their corresponding labels (if any). @@ -527,7 +527,7 @@ def _get_sanitize_keypoints_mask( key_points: torch.Tensor, canvas_size: tuple[int, int], min_valid_edge_distance: int = 0, - min_invalid_points: int | float = 1, + min_invalid_points: Union[int, float] = 1, ) -> torch.Tensor: image_h, image_w = canvas_size From 4ba1abfdafb29f9ace8cb13871a43c80490d3525 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Tue, 7 Oct 2025 11:56:42 -0700 Subject: [PATCH 6/6] fix type error --- torchvision/transforms/v2/_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 1c0df9792e5..ad7ec6db087 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -517,7 +517,7 @@ class SanitizeKeyPoints(Transform): def __init__( self, min_valid_edge_distance: int = 0, - min_invalid_points: int | float = 1, + min_invalid_points: Union[int, float] = 1, labels_getter: Union[Callable[[Any], Any], str, None] = None, ) -> None: super().__init__()