@@ -442,3 +442,128 @@ def _get_sanitize_bounding_boxes_mask(
442442 valid &= (bounding_boxes [..., 4 ] <= image_w ) & (bounding_boxes [..., 5 ] <= image_h )
443443 valid &= (bounding_boxes [..., 6 ] <= image_w ) & (bounding_boxes [..., 7 ] <= image_h )
444444 return valid
445+
446+
def sanitize_keypoints(
    key_points: torch.Tensor,
    canvas_size: Optional[tuple[int, int]] = None,
    min_valid_edge_distance: int = 0,
    min_invalid_points: int | float = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Drop keypoints (or groups of keypoints) that fall outside the usable image area.

    A keypoint counts as invalid when it lies outside of its image, or closer than
    ``min_valid_edge_distance`` pixels to one of the image edges. Invalid keypoints, or whole
    groups containing enough invalid keypoints, are removed from the output; the returned mask
    lets the caller drop the associated entries (e.g. labels) as well.
    To clamp such keypoints to the image edges instead of removing them, use
    :class:`~torchvision.transforms.v2.ClampKeyPoints`.

    It is recommended to call this at the end of a pipeline, right before the input is
    handed to the model.

    The input is either a set of individual keypoints of shape ``[N_points, 2]`` or a set of
    objects (e.g., polygons or polygonal chains), each made of a fixed number of keypoints, of
    shape ``[N_objects, ..., 2]``. For grouped input (a tensor with at least 3 dimensions) only
    entire groups are removed, never single keypoints inside a group.

    Args:
        key_points (Tensor or :class:`~torchvision.tv_tensors.KeyPoints`): The keypoints to be sanitized.
        canvas_size (tuple of int, optional): The canvas size of the keypoints, i.e. the size of the
            corresponding image/video. Required when ``key_points`` is a pure tensor; must be ``None``
            when ``key_points`` is a :class:`~torchvision.tv_tensors.KeyPoints` object, whose own
            ``canvas_size`` is used instead.
        min_valid_edge_distance (int, optional): Minimum distance, along any axis, that a keypoint must keep
            from the closest image edge to be considered valid. With 0 only keypoints outside of the image
            area are invalidated; with 1 keypoints lying exactly on the edge are invalidated too.
            Default is 0.
        min_invalid_points (int or float, optional): Minimum number (int) or fraction (float) of invalid
            keypoints required for a group to be removed. With 1 a group is removed as soon as any of its
            keypoints is invalid; with 2 at least two invalid keypoints are needed. A float in (0.0, 1.0)
            is interpreted relative to the number of keypoints per group, e.g. 0.3 removes groups with at
            least 30% invalid keypoints. Note that `1` (integer) and `1.0` (float) differ a lot: the
            former removes groups with any invalid keypoint, the latter only groups in which all
            keypoints are invalid.
            Default is 1.

    Returns:
        out (tuple of Tensors): The subset of valid keypoints, and the boolean indexing mask that produced
        it. The mask can be used to subset other tensors (e.g. labels) associated with the keypoints.

    Raises:
        ValueError: If ``canvas_size`` is missing for a pure tensor, is given for a
            :class:`~torchvision.tv_tensors.KeyPoints` input, or if ``key_points`` is neither of the two.
    """
    if torch.jit.is_scripting() or is_pure_tensor(key_points):
        # Plain tensors carry no metadata, so the caller has to supply the canvas size.
        if canvas_size is None:
            missing_size_msg = (
                "canvas_size cannot be None if key_points is a pure tensor. "
                "Set it to an appropriate value or pass key_points as a tv_tensors.KeyPoints object."
            )
            raise ValueError(missing_size_msg)
        keep = _get_sanitize_keypoints_mask(
            key_points, canvas_size, min_valid_edge_distance, min_invalid_points
        )
        sanitized = key_points[keep]
    else:
        if not isinstance(key_points, tv_tensors.KeyPoints):
            raise ValueError("key_points must be a tv_tensors.KeyPoints instance or a pure tensor.")
        # KeyPoints objects bring their own canvas size; an explicit one would be ambiguous.
        if canvas_size is not None:
            redundant_size_msg = (
                "canvas_size must be None when key_points is a tv_tensors.KeyPoints instance. "
                f"Got canvas_size={ canvas_size } . "
                "Leave it to None or pass key_points as a pure tensor."
            )
            raise ValueError(redundant_size_msg)
        keep = _get_sanitize_keypoints_mask(
            key_points, key_points.canvas_size, min_valid_edge_distance, min_invalid_points
        )
        # Indexing yields a new tensor; re-wrap it so the output is typed like the input.
        sanitized = tv_tensors.wrap(key_points[keep], like=key_points)

    return sanitized, keep
524+
525+
526+ def _get_sanitize_keypoints_mask (
527+ key_points : torch .Tensor ,
528+ canvas_size : tuple [int , int ],
529+ min_valid_edge_distance : int = 0 ,
530+ min_invalid_points : int | float = 1 ,
531+ ) -> torch .Tensor :
532+
533+ image_h , image_w = canvas_size
534+
535+ # Bring keypoints tensor into canonical shape [N_instances, N_points, 2]
536+ if key_points .ndim == 2 :
537+ key_points = key_points .unsqueeze (dim = 1 )
538+ elif key_points .ndim > 3 :
539+ key_points = key_points .flatten (start_dim = 1 , end_dim = - 2 )
540+
541+ # Convert min_invalid_points from relative to absolute number of points
542+ if min_invalid_points <= 0 :
543+ raise ValueError (f"min_invalid_points must be > 0. Got { min_invalid_points } ." )
544+ if isinstance (min_invalid_points , float ):
545+ min_invalid_points = math .ceil (min_invalid_points * key_points .shape [1 ])
546+ if min_invalid_points > 1 and key_points .shape [1 ] == 1 :
547+ raise ValueError (
548+ f"min_invalid_points was set to { min_invalid_points } , but key_points only contains a single point per "
549+ "instance, so min_invalid_points must be 1."
550+ )
551+
552+ # Compute distance of each point to the closest image edge
553+ dists = torch .stack (
554+ [
555+ key_points [..., 0 ], # x
556+ image_w - 1 - key_points [..., 0 ], # image_w - x
557+ key_points [..., 1 ], # y
558+ image_h - 1 - key_points [..., 1 ], # image_h - y
559+ ],
560+ dim = - 1 ,
561+ )
562+ dists = dists .min (dim = - 1 ).values # [N_instances, N_points]
563+
564+ # Determine invalid points
565+ invalid_points = dists < min_valid_edge_distance # [N_instances, N_points]
566+
567+ # Determine valid instances
568+ valid = invalid_points .sum (dim = - 1 ) < min_invalid_points # [N_instances]
569+ return valid
0 commit comments