
Commit 51a6574

Add SanitizeKeyPoints transform
1 parent 7a13ad0 commit 51a6574

File tree: 6 files changed (+266, -2 lines)

docs/source/transforms.rst

Lines changed: 3 additions & 0 deletions

@@ -413,6 +413,7 @@ Miscellaneous
     v2.RandomErasing
     v2.Lambda
     v2.SanitizeBoundingBoxes
+    v2.SanitizeKeyPoints
     v2.ClampBoundingBoxes
     v2.ClampKeyPoints
     v2.UniformTemporalSubsample
@@ -427,6 +428,7 @@ Functionals
     v2.functional.normalize
     v2.functional.erase
     v2.functional.sanitize_bounding_boxes
+    v2.functional.sanitize_keypoints
     v2.functional.clamp_bounding_boxes
     v2.functional.clamp_keypoints
     v2.functional.uniform_temporal_subsample
@@ -530,6 +532,7 @@ Developer tools
     v2.query_size
     v2.query_chw
     v2.get_bounding_boxes
+    v2.get_keypoints


 V1 API Reference

torchvision/transforms/v2/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -51,10 +51,11 @@
     LinearTransformation,
     Normalize,
     SanitizeBoundingBoxes,
+    SanitizeKeyPoints,
     ToDtype,
 )
 from ._temporal import UniformTemporalSubsample
 from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
-from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
+from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size

 from ._deprecated import ToTensor  # usort: skip

torchvision/transforms/v2/_misc.py

Lines changed: 123 additions & 1 deletion

@@ -10,7 +10,15 @@
 from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform

-from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
+from ._utils import (
+    _parse_labels_getter,
+    _setup_number_or_seq,
+    _setup_size,
+    get_bounding_boxes,
+    get_keypoints,
+    has_any,
+    is_pure_tensor,
+)


 # TODO: do we want/need to expose this?
@@ -459,3 +467,117 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
             return output
         else:
             return tv_tensors.wrap(output, like=inpt)
+
+
+class SanitizeKeyPoints(Transform):
+    """Remove keypoints outside of the image area and their corresponding labels (if any).
+
+    This transform removes keypoints or groups of keypoints and their associated labels that
+    have coordinates outside of their corresponding image or within ``min_valid_edge_distance`` pixels
+    from the image edges.
+    If you would instead like to clamp such keypoints to the image edges, use
+    :class:`~torchvision.transforms.v2.ClampKeyPoints`.
+
+    It is recommended to call it at the end of a pipeline, before passing the
+    input to the models.
+
+    Keypoints can be passed as a set of individual keypoints of shape ``[N_points, 2]`` or as a
+    set of objects (e.g., polygons or polygonal chains) consisting of a fixed number of keypoints
+    of shape ``[N_objects, ..., 2]``.
+    When groups of keypoints are passed (i.e., an at least 3-dimensional tensor), this transform
+    will only remove entire groups, not individual keypoints within a group.
+
+    Args:
+        min_valid_edge_distance (int, optional): The minimum distance that keypoints need to be away from the
+            closest image edge along any axis in order to be considered valid. For example, setting this to 0
+            will only invalidate/remove keypoints outside of the image area, while a value of 1 will also remove
+            keypoints lying exactly on the edge. Default is 0.
+        min_invalid_points (int or float, optional): Minimum number or fraction of invalid keypoints required
+            for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints
+            if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2
+            invalid keypoints.
+            If a float in (0.0, 1.0) is passed, it represents a fraction of the total number of keypoints in
+            the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid
+            keypoints.
+            Note that a value of ``1`` (integer) is very different from ``1.0`` (float). The former will remove
+            groups with any invalid keypoint, while the latter will only remove groups where all keypoints are
+            invalid. Default is 1.
+        labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
+            (or anything else that needs to be sanitized along with the keypoints).
+            By default, this will try to find a "labels" key in the input (case-insensitive), if
+            the input is a dict or it is a tuple whose second element is a dict.
+
+            It can also be a callable that takes the same input as the transform, and returns either:
+
+            - A single tensor (the labels)
+            - A tuple/list of tensors, each of which will be subject to the same sanitization as the keypoints.
+
+            If ``labels_getter`` is None then only keypoints are sanitized.
+    """
+
+    def __init__(
+        self,
+        min_valid_edge_distance: int = 0,
+        min_invalid_points: int | float = 1,
+        labels_getter: Union[Callable[[Any], Any], str, None] = "default",
+    ) -> None:
+        super().__init__()
+        self.min_valid_edge_distance = min_valid_edge_distance
+        self.min_invalid_points = min_invalid_points
+        self.labels_getter = labels_getter
+        self._labels_getter = _parse_labels_getter(labels_getter)
+
+        if min_invalid_points <= 0:
+            raise ValueError(f"min_invalid_points must be > 0. Got {min_invalid_points}.")
+
+    def forward(self, *inputs: Any) -> Any:
+        inputs = inputs if len(inputs) > 1 else inputs[0]
+
+        labels = self._labels_getter(inputs)
+        if labels is not None:
+            msg = "The labels in the input to forward() must be a tensor or None, got {type} instead."
+            if isinstance(labels, torch.Tensor):
+                labels = (labels,)
+            elif isinstance(labels, (tuple, list)):
+                for entry in labels:
+                    if not isinstance(entry, torch.Tensor):
+                        # TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask]
+                        raise ValueError(msg.format(type=type(entry)))
+            else:
+                raise ValueError(msg.format(type=type(labels)))
+
+        flat_inputs, spec = tree_flatten(inputs)
+        points = get_keypoints(flat_inputs)
+
+        if labels is not None:
+            for label in labels:
+                if points.shape[0] != label.shape[0]:
+                    raise ValueError(
+                        f"Number of keypoints (shape={points.shape}) must match the number of labels. "
+                        f"Found labels with shape={label.shape})."
+                    )
+
+        valid = F._misc._get_sanitize_keypoints_mask(
+            points,
+            canvas_size=points.canvas_size,
+            min_valid_edge_distance=self.min_valid_edge_distance,
+            min_invalid_points=self.min_invalid_points,
+        )
+
+        params = dict(valid=valid, labels=labels)
+        flat_outputs = [self.transform(inpt, params) for inpt in flat_inputs]
+
+        return tree_unflatten(flat_outputs, spec)
+
+    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
+        is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
+        is_keypoints = isinstance(inpt, tv_tensors.KeyPoints)
+
+        if not (is_label or is_keypoints):
+            return inpt
+
+        output = inpt[params["valid"]]
+
+        if is_label:
+            return output
+        else:
+            return tv_tensors.wrap(output, like=inpt)
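
For readers unfamiliar with the v2 sanitization transforms, here is a minimal usage sketch (not part of this commit); the sample values and the expected outputs noted in the comments are assumptions derived from the docstring and mask logic above:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# Three keypoints on a 10x10 canvas; the last one lies outside the image.
points = tv_tensors.KeyPoints(
    [[2.0, 3.0], [9.0, 9.0], [15.0, 4.0]],
    canvas_size=(10, 10),
)
sample = {"keypoints": points, "labels": torch.tensor([0, 1, 2])}

# The default labels_getter finds the "labels" key of the dict, so the label
# of the removed keypoint is dropped along with it.
out = v2.SanitizeKeyPoints()(sample)
# Expected: out["keypoints"].shape == (2, 2) and out["labels"] == tensor([0, 1])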

torchvision/transforms/v2/_utils.py

Lines changed: 12 additions & 0 deletions

@@ -165,6 +165,18 @@ def get_bounding_boxes(flat_inputs: list[Any]) -> tv_tensors.BoundingBoxes:
         raise ValueError("No bounding boxes were found in the sample")


+def get_keypoints(flat_inputs: list[Any]) -> tv_tensors.KeyPoints:
+    """Return the keypoints in the input.
+
+    Assumes only one ``KeyPoints`` object is present.
+    """
+    # This assumes there is only one keypoint per sample as per the general convention
+    try:
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.KeyPoints))
+    except StopIteration:
+        raise ValueError("No keypoints were found in the sample")
+
+
 def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
     """Return Channel, Height, and Width."""
     chws = {
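
As a small illustration (not part of the diff), get_keypoints() mirrors get_bounding_boxes(): it expects the already-flattened sample that Transform.forward() produces via tree_flatten and returns the single KeyPoints entry, raising if none is found:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import get_keypoints

image = torch.rand(3, 10, 10)
points = tv_tensors.KeyPoints([[1.0, 2.0]], canvas_size=(10, 10))

# The helper scans the flat list and returns the KeyPoints object itself.
assert get_keypoints([image, points]) is points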

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -156,6 +156,7 @@
     normalize_image,
     normalize_video,
     sanitize_bounding_boxes,
+    sanitize_keypoints,
     to_dtype,
     to_dtype_image,
     to_dtype_video,

torchvision/transforms/v2/functional/_misc.py

Lines changed: 125 additions & 0 deletions

@@ -442,3 +442,128 @@ def _get_sanitize_bounding_boxes_mask(
         valid &= (bounding_boxes[..., 4] <= image_w) & (bounding_boxes[..., 5] <= image_h)
         valid &= (bounding_boxes[..., 6] <= image_w) & (bounding_boxes[..., 7] <= image_h)
     return valid
+
+
+def sanitize_keypoints(
+    key_points: torch.Tensor,
+    canvas_size: Optional[tuple[int, int]] = None,
+    min_valid_edge_distance: int = 0,
+    min_invalid_points: int | float = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Remove keypoints outside of the image area and return the corresponding indexing mask.
+
+    This function removes keypoints or groups of keypoints that have coordinates outside of their
+    corresponding image or within ``min_valid_edge_distance`` pixels from the image edges.
+    If you would instead like to clamp such keypoints to the image edges, use
+    :class:`~torchvision.transforms.v2.ClampKeyPoints`.
+
+    It is recommended to call it at the end of a pipeline, before passing the
+    input to the models.
+
+    Keypoints can be passed as a set of individual keypoints of shape ``[N_points, 2]`` or as a
+    set of objects (e.g., polygons or polygonal chains) consisting of a fixed number of keypoints
+    of shape ``[N_objects, ..., 2]``.
+    When groups of keypoints are passed (i.e., an at least 3-dimensional tensor), this function
+    will only remove entire groups, not individual keypoints within a group.
+
+    Args:
+        key_points (Tensor or :class:`~torchvision.tv_tensors.KeyPoints`): The keypoints to be sanitized.
+        canvas_size (tuple of int, optional): The canvas_size of the keypoints
+            (size of the corresponding image/video).
+            Must be left to ``None`` if ``key_points`` is a :class:`~torchvision.tv_tensors.KeyPoints` object.
+        min_valid_edge_distance (int, optional): The minimum distance that keypoints need to be away from the
+            closest image edge along any axis in order to be considered valid. For example, setting this to 0
+            will only invalidate/remove keypoints outside of the image area, while a value of 1 will also remove
+            keypoints lying exactly on the edge. Default is 0.
+        min_invalid_points (int or float, optional): Minimum number or fraction of invalid keypoints required
+            for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints
+            if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2
+            invalid keypoints.
+            If a float in (0.0, 1.0) is passed, it represents a fraction of the total number of keypoints in
+            the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid
+            keypoints.
+            Note that a value of ``1`` (integer) is very different from ``1.0`` (float). The former will remove
+            groups with any invalid keypoint, while the latter will only remove groups where all keypoints are
+            invalid. Default is 1.
+
+    Returns:
+        out (tuple of Tensors): The subset of valid keypoints, and the corresponding indexing mask.
+        The mask can then be used to subset other tensors (e.g. labels) that are associated with the keypoints.
+    """
+    if torch.jit.is_scripting() or is_pure_tensor(key_points):
+        if canvas_size is None:
+            raise ValueError(
+                "canvas_size cannot be None if key_points is a pure tensor. "
+                "Set it to an appropriate value or pass key_points as a tv_tensors.KeyPoints object."
+            )
+        valid = _get_sanitize_keypoints_mask(
+            key_points,
+            canvas_size=canvas_size,
+            min_valid_edge_distance=min_valid_edge_distance,
+            min_invalid_points=min_invalid_points,
+        )
+        key_points = key_points[valid]
+    else:
+        if not isinstance(key_points, tv_tensors.KeyPoints):
+            raise ValueError("key_points must be a tv_tensors.KeyPoints instance or a pure tensor.")
+        if canvas_size is not None:
+            raise ValueError(
+                "canvas_size must be None when key_points is a tv_tensors.KeyPoints instance. "
+                f"Got canvas_size={canvas_size}. "
+                "Leave it to None or pass key_points as a pure tensor."
+            )
+        valid = _get_sanitize_keypoints_mask(
+            key_points,
+            canvas_size=key_points.canvas_size,
+            min_valid_edge_distance=min_valid_edge_distance,
+            min_invalid_points=min_invalid_points,
+        )
+        key_points = tv_tensors.wrap(key_points[valid], like=key_points)
+
+    return key_points, valid
+
+
+def _get_sanitize_keypoints_mask(
+    key_points: torch.Tensor,
+    canvas_size: tuple[int, int],
+    min_valid_edge_distance: int = 0,
+    min_invalid_points: int | float = 1,
+) -> torch.Tensor:
+
+    image_h, image_w = canvas_size
+
+    # Bring keypoints tensor into canonical shape [N_instances, N_points, 2]
+    if key_points.ndim == 2:
+        key_points = key_points.unsqueeze(dim=1)
+    elif key_points.ndim > 3:
+        key_points = key_points.flatten(start_dim=1, end_dim=-2)
+
+    # Convert min_invalid_points from relative to absolute number of points
+    if min_invalid_points <= 0:
+        raise ValueError(f"min_invalid_points must be > 0. Got {min_invalid_points}.")
+    if isinstance(min_invalid_points, float):
+        min_invalid_points = math.ceil(min_invalid_points * key_points.shape[1])
+    if min_invalid_points > 1 and key_points.shape[1] == 1:
+        raise ValueError(
+            f"min_invalid_points was set to {min_invalid_points}, but key_points only contains a single point per "
+            "instance, so min_invalid_points must be 1."
+        )
+
+    # Compute distance of each point to the closest image edge
+    dists = torch.stack(
+        [
+            key_points[..., 0],  # x
+            image_w - 1 - key_points[..., 0],  # image_w - x
+            key_points[..., 1],  # y
+            image_h - 1 - key_points[..., 1],  # image_h - y
+        ],
+        dim=-1,
+    )
+    dists = dists.min(dim=-1).values  # [N_instances, N_points]
+
+    # Determine invalid points
+    invalid_points = dists < min_valid_edge_distance  # [N_instances, N_points]
+
+    # Determine valid instances
+    valid = invalid_points.sum(dim=-1) < min_invalid_points  # [N_instances]
+    return valid
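
To make the group semantics concrete, here is a short sketch of the functional form (an illustration, not part of the commit); it uses a pure tensor, so canvas_size must be passed explicitly, and the expected masks noted in the comments follow from the ceil-based fraction handling above:

import torch
from torchvision.transforms.v2 import functional as F

# Two triangles (groups of 3 keypoints each) on a 10x10 canvas; the second
# triangle has two of its three vertices outside the image.
triangles = torch.tensor(
    [
        [[1.0, 1.0], [4.0, 1.0], [2.0, 5.0]],
        [[8.0, 8.0], [12.0, 8.0], [8.0, 12.0]],
    ]
)

# Default min_invalid_points=1: any invalid vertex removes the whole group.
kept, valid = F.sanitize_keypoints(triangles, canvas_size=(10, 10))
# Expected: kept.shape == (1, 3, 2) and valid == tensor([True, False])

# Fractional threshold: ceil(0.8 * 3) = 3 vertices must be invalid before a
# group is dropped, so the second triangle (2/3 invalid) is kept this time.
kept, valid = F.sanitize_keypoints(triangles, canvas_size=(10, 10), min_invalid_points=0.8)
# Expected: valid == tensor([True, True])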
