Skip to content

Commit 73a40a8

Browse files
committed
Fixed docstring on sanitize_keypoints
1 parent 801e24d commit 73a40a8

File tree

1 file changed

+13
-4
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+13
-4
lines changed

torchvision/transforms/v2/functional/_misc.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -334,17 +334,24 @@ def sanitize_keypoints(
334334
"""Removes degenerate/invalid keypoints and returns the corresponding indexing mask.
335335
336336
This removes the keypoints that are outside of their corresponing image.
337-
You may want to first call :func:`~torchvision.transforms.v2.functional.clam_keypoints`
338-
first to avoid undesired removals.
337+
338+
It is recommended to call it at the end of a pipeline, before passing the
339+
input to the models. It is critical to call this transform if
340+
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
341+
If you want to be extra careful, you may call it after all transforms that
342+
may modify the key points but once at the end should be enough in most
343+
cases.
339344
340345
.. note::
341-
Points that touch the edge of the canvas are removed, unlike for :func:`sanitize_bounding_boxes`
346+
347+
Points that touch the edge of the canvas are removed, unlike for :func:`sanitize_bounding_boxes`.
342348
343349
Raises:
344350
ValueError: If the keypoints are not passed as a two dimensional tensor.
345351
346352
Args:
347-
keypoints (torch.Tensor or class:`~torchvision.tv_tensors.KeyPoints`): The Keypoints being removed
353+
keypoints (torch.Tensor or :class:`~torchvision.tv_tensors.KeyPoints`): The Keypoints being sanitized.
354+
Should be of shape ``[N, 2]``
348355
canvas_size (Optional[tuple[int, int]], optional): The canvas_size of the bounding boxes
349356
(size of the corresponding image/video).
350357
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.KeyPoints` object.
@@ -372,8 +379,10 @@ def sanitize_keypoints(
372379
canvas_size=canvas_size,
373380
)
374381
return keypoints[valid], valid
382+
375383
if not isinstance(keypoints, tv_tensors.KeyPoints):
376384
raise ValueError("keypoints must be a tv_tensors.KeyPoints instance or a pure tensor.")
385+
377386
valid = _get_sanitize_keypoints_mask(
378387
keypoints,
379388
canvas_size=keypoints.canvas_size,

0 commit comments

Comments
 (0)