Skip to content
Open
10 changes: 7 additions & 3 deletions torchvision/tv_tensors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TypeVar, cast

import torch

from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat
Expand All @@ -7,12 +9,14 @@
from ._tv_tensor import TVTensor
from ._video import Video

TVTensorLike = TypeVar("TVTensorLike", bound=TVTensor, covariant=True)


# TODO: Fix this. We skip this method as it leads to
# RecursionError: maximum recursion depth exceeded while calling a Python object
# Until `disable` is removed, there will be graph breaks after all calls to functional transforms
@torch.compiler.disable
def wrap(wrappee, *, like, **kwargs):
def wrap(wrappee: torch.Tensor, *, like: TVTensorLike, **kwargs) -> TVTensorLike: # type: ignore
"""Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.

If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
Expand All @@ -26,10 +30,10 @@ def wrap(wrappee, *, like, **kwargs):
Ignored otherwise.
"""
if isinstance(like, BoundingBoxes):
return BoundingBoxes._wrap(
return cast(TVTensorLike, BoundingBoxes._wrap(
wrappee,
format=kwargs.get("format", like.format),
canvas_size=kwargs.get("canvas_size", like.canvas_size),
)
))
else:
return wrappee.as_subclass(type(like))