@@ -42,6 +42,17 @@ class KeyPoints(TVTensor):
4242
4343 canvas_size : tuple [int , int ]
4444
45+ @classmethod
46+ def _wrap (cls , tensor : torch .Tensor , * , canvas_size : tuple [int , int ], check_dims : bool = True ) -> KeyPoints : # type: ignore[override]
47+ if check_dims :
48+ if tensor .ndim == 1 :
49+ tensor = tensor .unsqueeze (0 )
50+ elif tensor .shape [- 1 ] != 2 :
51+ raise ValueError (f"Expected a tensor of shape (..., 2), not { tensor .shape } " )
52+ points = tensor .as_subclass (cls )
53+ points .canvas_size = canvas_size
54+ return points
55+
4556 def __new__ (
4657 cls ,
4758 data : Any ,
@@ -51,14 +62,8 @@ def __new__(
5162 device : torch .device | str | int | None = None ,
5263 requires_grad : bool | None = None ,
5364 ) -> KeyPoints :
54- tensor : torch .Tensor = cls ._to_tensor (data , dtype = dtype , device = device , requires_grad = requires_grad )
55- if tensor .ndim == 1 :
56- tensor = tensor .unsqueeze (0 )
57- elif tensor .shape [- 1 ] != 2 :
58- raise ValueError (f"Expected a tensor of shape (..., 2), not { tensor .shape } " )
59- points = tensor .as_subclass (cls )
60- points .canvas_size = canvas_size
61- return points
65+ tensor = cls ._to_tensor (data , dtype = dtype , device = device , requires_grad = requires_grad )
66+ return cls ._wrap (tensor , canvas_size = canvas_size )
6267
6368 @classmethod
6469 def _wrap_output (
@@ -75,17 +80,10 @@ def _wrap_output(
7580 canvas_size = first_bbox_from_args .canvas_size
7681
7782 if isinstance (output , torch .Tensor ) and not isinstance (output , KeyPoints ):
78- output = KeyPoints (output , canvas_size = canvas_size )
79- elif isinstance (output , MutableSequence ):
80- # For lists and list-like object we don't try to create a new object, we just set the values in the list
81- # This allows us to conserve the type of complex list-like object that may not follow the initialization API of lists
82- for i , part in enumerate (output ):
83- output [i ] = KeyPoints (part , canvas_size = canvas_size )
84- elif isinstance (output , Sequence ):
85- # Non-mutable sequences handled here (like tuples)
86- # Every sequence that is not a mutable sequence is a non-mutable sequence
87- # We have to use a tuple here, since we know its initialization api, unlike for `output`
88- output = tuple (KeyPoints (part , canvas_size = canvas_size ) for part in output )
83+ output = KeyPoints ._wrap (output , canvas_size = canvas_size , check_dims = False )
84+ elif isinstance (output , (tuple , list )):
85+ # This branch exists for chunk() and unbind()
86+ output = type (output )(KeyPoints ._wrap (part , canvas_size = canvas_size , check_dims = False ) for part in output )
8987 return output
9088
9189 def __repr__ (self , * , tensor_contents : Any = None ) -> str : # type: ignore[override]
0 commit comments