66from torch .utils ._pytree import tree_flatten
77
88from ._tv_tensor import TVTensor
9+ from ._bounding_boxes import CLAMPING_MODE_TYPE
910
1011
1112class KeyPoints (TVTensor ):
@@ -43,6 +44,8 @@ class KeyPoints(TVTensor):
4344 :func:`torch.as_tensor`.
4445 canvas_size (two-tuple of ints): Height and width of the corresponding
4546 image or video.
47+ clamping_mode: The clamping mode to use when applying transforms that may result in key points
48+ outside of the image. Possible values are: "soft", "hard", or ``None``. Read more in :ref:`clamping_mode_tuto`.
4649 dtype (torch.dtype, optional): Desired data type of the key points. If
4750 omitted, will be inferred from ``data``.
4851 device (torch.device, optional): Desired device of the key points. If
@@ -55,29 +58,34 @@ class KeyPoints(TVTensor):
5558 """
5659
5760 canvas_size : tuple [int , int ]
61+ clamping_mode : CLAMPING_MODE_TYPE
5862
5963 @classmethod
60- def _wrap (cls , tensor : torch .Tensor , * , canvas_size : tuple [int , int ], check_dims : bool = True ) -> KeyPoints : # type: ignore[override]
64+ def _wrap (cls , tensor : torch .Tensor , * , canvas_size : tuple [int , int ], clamping_mode : CLAMPING_MODE_TYPE = "soft" , check_dims : bool = True ) -> KeyPoints : # type: ignore[override]
6165 if check_dims :
6266 if tensor .ndim == 1 :
6367 tensor = tensor .unsqueeze (0 )
6468 elif tensor .shape [- 1 ] != 2 :
6569 raise ValueError (f"Expected a tensor of shape (..., 2), not { tensor .shape } " )
70+ if clamping_mode is not None and clamping_mode not in ("hard" , "soft" ):
71+ raise ValueError (f"clamping_mode must be None, hard or soft, got { clamping_mode } ." )
6672 points = tensor .as_subclass (cls )
6773 points .canvas_size = canvas_size
74+ points .clamping_mode = clamping_mode
6875 return points
6976
7077 def __new__ (
7178 cls ,
7279 data : Any ,
7380 * ,
7481 canvas_size : tuple [int , int ],
82+ clamping_mode : CLAMPING_MODE_TYPE = "soft" ,
7583 dtype : torch .dtype | None = None ,
7684 device : torch .device | str | int | None = None ,
7785 requires_grad : bool | None = None ,
7886 ) -> KeyPoints :
7987 tensor = cls ._to_tensor (data , dtype = dtype , device = device , requires_grad = requires_grad )
80- return cls ._wrap (tensor , canvas_size = canvas_size )
88+ return cls ._wrap (tensor , canvas_size = canvas_size , clamping_mode = clamping_mode )
8189
8290 @classmethod
8391 def _wrap_output (
@@ -89,14 +97,14 @@ def _wrap_output(
8997 # Similar to BoundingBoxes._wrap_output(), see comment there.
9098 flat_params , _ = tree_flatten (args + (tuple (kwargs .values ()) if kwargs else ())) # type: ignore[operator]
9199 first_keypoints_from_args = next (x for x in flat_params if isinstance (x , KeyPoints ))
92- canvas_size = first_keypoints_from_args .canvas_size
100+ canvas_size , clamping_mode = first_keypoints_from_args .canvas_size , first_keypoints_from_args . clamping_mode
93101
94102 if isinstance (output , torch .Tensor ) and not isinstance (output , KeyPoints ):
95- output = KeyPoints ._wrap (output , canvas_size = canvas_size , check_dims = False )
103+ output = KeyPoints ._wrap (output , canvas_size = canvas_size , clamping_mode = clamping_mode , check_dims = False )
96104 elif isinstance (output , (tuple , list )):
97105 # This branch exists for chunk() and unbind()
98- output = type (output )(KeyPoints ._wrap (part , canvas_size = canvas_size , check_dims = False ) for part in output )
106+ output = type (output )(KeyPoints ._wrap (part , canvas_size = canvas_size , clamping_mode = clamping_mode , check_dims = False ) for part in output )
99107 return output
100108
101109 def __repr__ (self , * , tensor_contents : Any = None ) -> str : # type: ignore[override]
102- return self ._make_repr (canvas_size = self .canvas_size )
110+ return self ._make_repr (canvas_size = self .canvas_size , clamping_mode = self . clamping_mode )
0 commit comments