@@ -328,82 +328,6 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo
328328 return inpt .to (dtype )
329329
330330
331- # TODOKP This is untested. Also there's no corresponding transform class
332- def sanitize_keypoints (
333- keypoints : torch .Tensor , canvas_size : Optional [tuple [int , int ]] = None
334- ) -> tuple [torch .Tensor , torch .Tensor ]:
335- """Removes degenerate/invalid keypoints and returns the corresponding indexing mask.
336-
337- This removes the keypoints that are outside of their corresponing image.
338-
339- It is recommended to call it at the end of a pipeline, before passing the
340- input to the models. It is critical to call this transform if
341- :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
342- If you want to be extra careful, you may call it after all transforms that
343- may modify the key points but once at the end should be enough in most
344- cases.
345-
346- .. note::
347-
348- Points that touch the edge of the canvas are removed, unlike for :func:`sanitize_bounding_boxes`.
349- TODOKP Is this desirable? We probably want keypoints to behave the same as bboxes?
350-
351- Raises:
352- ValueError: If the keypoints are not passed as a two dimensional tensor.
353-
354- Args:
355- keypoints (torch.Tensor or :class:`~torchvision.tv_tensors.KeyPoints`): The Keypoints being sanitized.
356- Should be of shape ``[N, 2]``
357- canvas_size (Optional[tuple[int, int]], optional): The canvas_size of the bounding boxes
358- (size of the corresponding image/video).
359- Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.KeyPoints` object.
360-
361- Returns:
362- out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
363- The mask can then be used to subset other tensors (e.g. labels) that are associated with the bounding boxes.
364- """
365- if not keypoints .ndim == 2 :
366- if keypoints .ndim < 2 :
367- raise ValueError ("Cannot sanitize a single Keypoint" )
368- raise ValueError (
369- "Cannot sanitize KeyPoints structure that are not 2D. "
370- f"Expected shape to be (N, 2), got { keypoints .shape } ({ keypoints .ndim = } , not 2)"
371- )
372- if torch .jit .is_scripting () or is_pure_tensor (keypoints ):
373- if canvas_size is None :
374- raise ValueError (
375- "canvas_size cannot be None if keypoints is a pure tensor. "
376- f"Got canvas_size={ canvas_size } ."
377- "Set that to appropriate values or pass keypoints as a tv_tensors.KeyPoints object."
378- )
379- valid = _get_sanitize_keypoints_mask (
380- keypoints ,
381- canvas_size = canvas_size ,
382- )
383- return keypoints [valid ], valid
384-
385- if not isinstance (keypoints , tv_tensors .KeyPoints ):
386- raise ValueError ("keypoints must be a tv_tensors.KeyPoints instance or a pure tensor." )
387-
388- valid = _get_sanitize_keypoints_mask (
389- keypoints ,
390- canvas_size = keypoints .canvas_size ,
391- )
392- return tv_tensors .wrap (keypoints [valid ], like = keypoints ), valid
393-
394-
395- # TODOKP Untested, see above
396- def _get_sanitize_keypoints_mask (
397- keypoints : torch .Tensor ,
398- canvas_size : tuple [int , int ],
399- ) -> torch .Tensor :
400- image_h , image_w = canvas_size
401- x = keypoints [:, 0 ]
402- y = keypoints [:, 1 ]
403-
404- return (0 < x ) & (x < image_w ) & (0 < y ) & (y < image_h )
405-
406-
407331def sanitize_bounding_boxes (
408332 bounding_boxes : torch .Tensor ,
409333 format : Optional [tv_tensors .BoundingBoxFormat ] = None ,
0 commit comments