
Commit d04a3e3

Proper fix for test_to_tv_tensor_reference

1 parent: f03f958

2 files changed (+18 −19 lines)

torchvision/tv_tensors/_bounding_boxes.py (1 addition, 0 deletions)

@@ -111,6 +111,7 @@ def _wrap_output(
         if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
             output = BoundingBoxes._wrap(output, format=format, canvas_size=canvas_size, check_dims=False)
         elif isinstance(output, (tuple, list)):
+            # This branch exists for chunk() and unbind()
             output = type(output)(
                 BoundingBoxes._wrap(part, format=format, canvas_size=canvas_size, check_dims=False) for part in output
             )
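
The added comment documents which ops exercise the tuple/list branch. A minimal sketch of that behavior, assuming a recent torchvision where tv_tensors.BoundingBoxes is available (the box values and canvas size below are made up for illustration):

```python
import torch
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10], [5, 5, 20, 20]]),
    format="XYXY",
    canvas_size=(32, 32),
)

# unbind() (like chunk()) returns a tuple, which hits the tuple/list branch
# of _wrap_output; each part should come back as a BoundingBoxes that keeps
# the original format and canvas_size.
for part in boxes.unbind(0):
    print(type(part).__name__, part.format, part.canvas_size)
```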

torchvision/tv_tensors/_keypoints.py (17 additions, 19 deletions)
@@ -42,6 +42,17 @@ class KeyPoints(TVTensor):
 
     canvas_size: tuple[int, int]
 
+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], check_dims: bool = True) -> KeyPoints:  # type: ignore[override]
+        if check_dims:
+            if tensor.ndim == 1:
+                tensor = tensor.unsqueeze(0)
+            elif tensor.shape[-1] != 2:
+                raise ValueError(f"Expected a tensor of shape (..., 2), not {tensor.shape}")
+        points = tensor.as_subclass(cls)
+        points.canvas_size = canvas_size
+        return points
+
     def __new__(
         cls,
         data: Any,
@@ -51,14 +62,8 @@ def __new__(
         device: torch.device | str | int | None = None,
         requires_grad: bool | None = None,
     ) -> KeyPoints:
-        tensor: torch.Tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        if tensor.ndim == 1:
-            tensor = tensor.unsqueeze(0)
-        elif tensor.shape[-1] != 2:
-            raise ValueError(f"Expected a tensor of shape (..., 2), not {tensor.shape}")
-        points = tensor.as_subclass(cls)
-        points.canvas_size = canvas_size
-        return points
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor, canvas_size=canvas_size)
 
     @classmethod
     def _wrap_output(
@@ -75,17 +80,10 @@ def _wrap_output(
             canvas_size = first_bbox_from_args.canvas_size
 
         if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints):
-            output = KeyPoints(output, canvas_size=canvas_size)
-        elif isinstance(output, MutableSequence):
-            # For lists and list-like objects we don't try to create a new object, we just set the values in the list.
-            # This allows us to conserve the type of complex list-like objects that may not follow the initialization API of lists.
-            for i, part in enumerate(output):
-                output[i] = KeyPoints(part, canvas_size=canvas_size)
-        elif isinstance(output, Sequence):
-            # Non-mutable sequences are handled here (like tuples).
-            # Every sequence that is not a mutable sequence is a non-mutable sequence.
-            # We have to use a tuple here, since we know its initialization API, unlike for `output`.
-            output = tuple(KeyPoints(part, canvas_size=canvas_size) for part in output)
+            output = KeyPoints._wrap(output, canvas_size=canvas_size, check_dims=False)
+        elif isinstance(output, (tuple, list)):
+            # This branch exists for chunk() and unbind()
+            output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, check_dims=False) for part in output)
         return output
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
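
For reference, a minimal sketch of what the consolidated wrapping path covers, assuming a torchvision build from main where KeyPoints is exported from torchvision.tv_tensors (the sample values are illustrative):

```python
import torch
from torchvision import tv_tensors

# A 1-D input is unsqueezed to shape (1, 2) by KeyPoints._wrap via __new__.
single = tv_tensors.KeyPoints(torch.tensor([3.0, 4.0]), canvas_size=(32, 32))
print(single.shape, single.canvas_size)

# chunk()/unbind() hit the tuple/list branch of _wrap_output, so each part
# is re-wrapped (with check_dims=False) and keeps its canvas_size.
kp = tv_tensors.KeyPoints(torch.rand(4, 2) * 32, canvas_size=(32, 32))
for part in kp.unbind(0):
    print(type(part).__name__, part.canvas_size)
```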
