Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ Miscellaneous
v2.RandomErasing
v2.Lambda
v2.SanitizeBoundingBoxes
v2.SanitizeKeyPoints
v2.ClampBoundingBoxes
v2.ClampKeyPoints
v2.UniformTemporalSubsample
Expand All @@ -427,6 +428,7 @@ Functionals
v2.functional.normalize
v2.functional.erase
v2.functional.sanitize_bounding_boxes
v2.functional.sanitize_keypoints
v2.functional.clamp_bounding_boxes
v2.functional.clamp_keypoints
v2.functional.uniform_temporal_subsample
Expand Down Expand Up @@ -530,6 +532,7 @@ Developer tools
v2.query_size
v2.query_chw
v2.get_bounding_boxes
v2.get_keypoints


V1 API Reference
Expand Down
337 changes: 325 additions & 12 deletions test/test_transforms_v2.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@
LinearTransformation,
Normalize,
SanitizeBoundingBoxes,
SanitizeKeyPoints,
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size

from ._deprecated import ToTensor # usort: skip
100 changes: 99 additions & 1 deletion torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform

from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
from ._utils import (
_parse_labels_getter,
_setup_number_or_seq,
_setup_size,
get_bounding_boxes,
get_keypoints,
has_any,
is_pure_tensor,
)


# TODO: do we want/need to expose this?
Expand Down Expand Up @@ -459,3 +467,93 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
return output
else:
return tv_tensors.wrap(output, like=inpt)


class SanitizeKeyPoints(Transform):
"""Remove keypoints outside of the image area and their corresponding labels (if any).

This transform removes keypoints or groups of keypoints and their associated labels that
have coordinates outside of their corresponding image.
If you would instead like to clamp such keypoints to the image edges, use
:class:`~torchvision.transforms.v2.ClampKeyPoints`.

It is recommended to call it at the end of a pipeline, before passing the
input to the models.

Keypoints can be passed as a set of individual keypoints or as a set of objects
(e.g., polygons or polygonal chains) consisting of a fixed number of keypoints of shape ``[..., 2]``.
When groups of keypoints are passed (i.e., an at least 3-dimensional tensor), this transform
will only remove entire groups, not individual keypoints within a group.

Args:
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
(or anything else that needs to be sanitized along with the keypoints).
If set to the string ``"default"``, this will try to find a "labels" key in the input (case-insensitive), if
the input is a dict or it is a tuple whose second element is a dict.

It can also be a callable that takes the same input as the transform, and returns either:

- A single tensor (the labels)
- A tuple/list of tensors, each of which will be subject to the same sanitization as the keypoints.

If ``labels_getter`` is None (the default), then only keypoints are sanitized.
"""

def __init__(
self,
labels_getter: Union[Callable[[Any], Any], str, None] = None,
) -> None:
super().__init__()
self.labels_getter = labels_getter
self._labels_getter = _parse_labels_getter(labels_getter)

def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]

labels = self._labels_getter(inputs)
if labels is not None:
msg = "The labels in the input to forward() must be a tensor or None, got {type} instead."
if isinstance(labels, torch.Tensor):
labels = (labels,)
elif isinstance(labels, (tuple, list)):
for entry in labels:
if not isinstance(entry, torch.Tensor):
# TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask]
raise ValueError(msg.format(type=type(entry)))
else:
raise ValueError(msg.format(type=type(labels)))

flat_inputs, spec = tree_flatten(inputs)
points = get_keypoints(flat_inputs)

if labels is not None:
for label in labels:
if points.shape[0] != label.shape[0]:
raise ValueError(
f"Number of kepyoints (shape={points.shape}) must match the number of labels."
f"Found labels with shape={label.shape})."
)

valid = F._misc._get_sanitize_keypoints_mask(
points,
canvas_size=points.canvas_size,
)

params = dict(valid=valid, labels=labels)
flat_outputs = [self.transform(inpt, params) for inpt in flat_inputs]

return tree_unflatten(flat_outputs, spec)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
is_keypoints = isinstance(inpt, tv_tensors.KeyPoints)

if not (is_label or is_keypoints):
return inpt

output = inpt[params["valid"]]

if is_label:
return output
else:
return tv_tensors.wrap(output, like=inpt)
12 changes: 12 additions & 0 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,18 @@ def get_bounding_boxes(flat_inputs: list[Any]) -> tv_tensors.BoundingBoxes:
raise ValueError("No bounding boxes were found in the sample")


def get_keypoints(flat_inputs: list[Any]) -> tv_tensors.KeyPoints:
"""Return the keypoints in the input.

Assumes only one ``KeyPoints`` object is present.
"""
# This assumes there is only one keypoint per sample as per the general convention
try:
return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.KeyPoints))
except StopIteration:
raise ValueError("No keypoints were found in the sample")


def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
"""Return Channel, Height, and Width."""
chws = {
Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@
normalize_image,
normalize_video,
sanitize_bounding_boxes,
sanitize_keypoints,
to_dtype,
to_dtype_image,
to_dtype_video,
Expand Down
16 changes: 8 additions & 8 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from torchvision.utils import _log_api_usage_once

from ._meta import _get_size_image_pil, clamp_bounding_boxes, clamp_keypoints, convert_bounding_box_format
from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal

Expand Down Expand Up @@ -71,7 +71,7 @@ def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, i
shape = keypoints.shape
keypoints = keypoints.clone().reshape(-1, 2)
keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_()
return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
return keypoints.reshape(shape)


@_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
Expand Down Expand Up @@ -159,7 +159,7 @@ def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int
shape = keypoints.shape
keypoints = keypoints.clone().reshape(-1, 2)
keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_()
return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
return keypoints.reshape(shape)


def vertical_flip_bounding_boxes(
Expand Down Expand Up @@ -1026,7 +1026,7 @@ def _affine_keypoints_with_expand(
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
canvas_size = (new_height, new_width)

out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size).reshape(original_shape)
out_keypoints = transformed_points.reshape(original_shape)
out_keypoints = out_keypoints.to(original_dtype)

return out_keypoints, canvas_size
Expand Down Expand Up @@ -1695,7 +1695,7 @@ def pad_keypoints(
left, right, top, bottom = _parse_pad_padding(padding)
pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right)
return clamp_keypoints(keypoints + pad, canvas_size), canvas_size
return keypoints + pad, canvas_size


@_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
Expand Down Expand Up @@ -1817,7 +1817,7 @@ def crop_keypoints(
keypoints = keypoints - torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
canvas_size = (height, width)

return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size
return keypoints, canvas_size


@_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
Expand Down Expand Up @@ -2047,7 +2047,7 @@ def perspective_keypoints(
numer_points = torch.matmul(points, theta1.T)
denom_points = torch.matmul(points, theta2.T)
transformed_points = numer_points.div_(denom_points)
return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size).reshape(original_shape)
return transformed_points.to(keypoints.dtype).reshape(original_shape)


@_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
Expand Down Expand Up @@ -2376,7 +2376,7 @@ def elastic_keypoints(
t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size).reshape(original_shape)
return transformed_points.to(keypoints.dtype).reshape(original_shape)


@_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
Expand Down
73 changes: 73 additions & 0 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,76 @@ def _get_sanitize_bounding_boxes_mask(
valid &= (bounding_boxes[..., 4] <= image_w) & (bounding_boxes[..., 5] <= image_h)
valid &= (bounding_boxes[..., 6] <= image_w) & (bounding_boxes[..., 7] <= image_h)
return valid


def sanitize_keypoints(
key_points: torch.Tensor,
canvas_size: Optional[tuple[int, int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Remove keypoints outside of the image area and their corresponding labels (if any).

This transform removes keypoints or groups of keypoints and their associated labels that
have coordinates outside of their corresponding image.
If you would instead like to clamp such keypoints to the image edges, use
:class:`~torchvision.transforms.v2.ClampKeyPoints`.

It is recommended to call it at the end of a pipeline, before passing the
input to the models.

Keypoints can be passed as a set of individual keypoints or as a set of objects
(e.g., polygons or polygonal chains) consisting of a fixed number of keypoints of shape ``[..., 2]``.
When groups of keypoints are passed (i.e., an at least 3-dimensional tensor),
this transform will only remove entire groups, not individual keypoints within a group.

Args:
key_points (Tensor or :class:`~torchvision.tv_tensors.KeyPoints`): The keypoints to be sanitized.
canvas_size (tuple of int, optional): The canvas_size of the keypoints
(size of the corresponding image/video).
Must be left to none if ``key_points`` is a :class:`~torchvision.tv_tensors.KeyPoints` object.

Returns:
out (tuple of Tensors): The subset of valid keypoints, and the corresponding indexing mask.
The mask can then be used to subset other tensors (e.g. labels) that are associated with the keypoints.
"""
if torch.jit.is_scripting() or is_pure_tensor(key_points):
if canvas_size is None:
raise ValueError(
"canvas_size cannot be None if key_points is a pure tensor. "
"Set it to an appropriate value or pass key_points as a tv_tensors.KeyPoints object."
)
valid = _get_sanitize_keypoints_mask(
key_points,
canvas_size=canvas_size,
)
key_points = key_points[valid]
else:
if not isinstance(key_points, tv_tensors.KeyPoints):
raise ValueError("key_points must be a tv_tensors.KeyPoints instance or a pure tensor.")
if canvas_size is not None:
raise ValueError(
"canvas_size must be None when key_points is a tv_tensors.KeyPoints instance. "
f"Got canvas_size={canvas_size}. "
"Leave it to None or pass key_points as a pure tensor."
)
valid = _get_sanitize_keypoints_mask(
key_points,
canvas_size=key_points.canvas_size,
)
key_points = tv_tensors.wrap(key_points[valid], like=key_points)

return key_points, valid


def _get_sanitize_keypoints_mask(
key_points: torch.Tensor,
canvas_size: tuple[int, int],
) -> torch.Tensor:

h, w = canvas_size

x, y = key_points[..., 0], key_points[..., 1]
valid = (x >= 0) & (x < w) & (y >= 0) & (y < h)

valid = valid.flatten(start_dim=1).all(dim=1) if valid.ndim > 1 else valid

return valid
Loading