Skip to content
Open
6 changes: 5 additions & 1 deletion test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
from torchvision.utils import _Image_fromarray


Expand Down Expand Up @@ -400,6 +400,10 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
    """Build a test image with ``make_image`` and convert it with ``to_cvcuda_tensor``.

    All positional and keyword arguments are forwarded unchanged to ``make_image``.
    """
    image = make_image(*args, **kwargs)
    return to_cvcuda_tensor(image)


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device)
x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device)
Expand Down
91 changes: 90 additions & 1 deletion test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
make_bounding_boxes,
make_detection_masks,
make_image,
make_image_cvcuda,
make_image_pil,
make_image_tensor,
make_keypoints,
Expand All @@ -51,8 +52,18 @@
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
from torchvision.transforms.v2.functional._utils import (
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
_register_kernel_internal,
)


# Availability flags are evaluated once at import time; the CV-CUDA test classes
# in this module use them in their ``skipif`` markers.
CUDA_AVAILABLE = torch.cuda.is_available()
CVCUDA_AVAILABLE = _is_cvcuda_available()
if CVCUDA_AVAILABLE:
    # Bind the module only when it can be imported; `_import_cvcuda` raises otherwise.
    cvcuda = _import_cvcuda()

# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]
Expand Down Expand Up @@ -6733,6 +6744,84 @@ def test_functional_error(self):
F.pil_to_tensor(object())


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
class TestToCVCUDATensor:
    """Tests for ``F.to_cvcuda_tensor`` and the ``transforms.ToCVCUDATensor`` transform."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
    @pytest.mark.parametrize(
        "fn",
        [F.to_cvcuda_tensor, transform_cls_to_functional(transforms.ToCVCUDATensor)],
    )
    def test_functional_and_transform(self, dtype, device, color_space, batch_dims, fn):
        """The functional and the transform both return a ``cvcuda.Tensor`` of the same spatial size."""
        image = make_image_tensor(dtype=dtype, device=device, color_space=color_space, batch_dims=batch_dims)

        output = fn(image)

        # The isinstance check already rules out ``None``, so no separate check is needed.
        assert isinstance(output, cvcuda.Tensor)
        assert F.get_size(output) == F.get_size(image)

    def test_invalid_input_type(self):
        """Non-tensor input is rejected with a ``TypeError``."""
        with pytest.raises(TypeError, match=r"inpt should be `torch.Tensor`"):
            F.to_cvcuda_tensor("invalid_input")

    @pytest.mark.parametrize("shape", [(4,), (4, 4), (1, 1, 3, 4, 4)])
    def test_invalid_dimensions(self, shape):
        """1-D, 2-D and 5-D tensors are rejected with a ``ValueError``."""
        # Build the input outside the ``raises`` block so that only the call under
        # test can satisfy the expectation.
        img_data = torch.randint(0, 256, shape, dtype=torch.uint8).cuda()

        with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
            F.to_cvcuda_tensor(img_data)

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("batch_size", [1, 2, 4])
    def test_round_trip(self, dtype, device, color_space, batch_size):
        """Converting to CV-CUDA and back reproduces the original tensor exactly."""
        original_tensor = make_image_tensor(
            dtype=dtype, device=device, color_space=color_space, batch_dims=(batch_size,)
        )

        cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
        result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)

        # Move the result back to the parametrized device before the exact comparison.
        torch.testing.assert_close(result_tensor.to(device), original_tensor, rtol=0, atol=0)
        assert result_tensor.shape[0] == batch_size


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
class TestCVDUDAToTensor:
    """Tests for ``F.cvcuda_to_tensor`` and the ``transforms.CVCUDAToTensor`` transform."""

    # NOTE(review): the class name looks like a typo for "TestCVCUDAToTensor"; it is
    # kept as-is here so the collected test ids do not change.

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
    @pytest.mark.parametrize(
        "fn",
        [F.cvcuda_to_tensor, transform_cls_to_functional(transforms.CVCUDAToTensor)],
    )
    def test_functional_and_transform(self, dtype, device, color_space, batch_dims, fn):
        """The functional and the transform both return a ``torch.Tensor`` of the expected size."""
        cvcuda_image = make_image_cvcuda(
            dtype=dtype, device=device, color_space=color_space, batch_dims=batch_dims
        )

        output = fn(cvcuda_image)

        assert isinstance(output, torch.Tensor)
        # The reference size comes from a direct functional conversion of the same input.
        reference = F.cvcuda_to_tensor(cvcuda_image)
        assert F.get_size(output) == F.get_size(reference)

    def test_functional_error(self):
        """Input that is not a ``cvcuda.Tensor`` is rejected with a ``TypeError``."""
        with pytest.raises(TypeError, match="cvcuda_img should be `cvcuda.Tensor`"):
            F.cvcuda_to_tensor(object())


class TestLambda:
@pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
from ._type_conversion import CVCUDAToTensor, PILToTensor, ToCVCUDATensor, ToImage, ToPILImage, ToPureTensor
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size

from ._deprecated import ToTensor # usort: skip
52 changes: 50 additions & 2 deletions torchvision/transforms/v2/_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Any, Optional, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import numpy as np
import PIL.Image
import torch

from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform

from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._utils import _import_cvcuda

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]


class PILToTensor(Transform):
Expand Down Expand Up @@ -90,3 +93,48 @@ class ToPureTensor(Transform):

def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor:
return inpt.as_subclass(torch.Tensor)


class ToCVCUDATensor(Transform):
    """Convert a `torch.Tensor` with NCHW layout to a `cvcuda.Tensor` with NHWC layout.
    If the input tensor is on CPU, it will automatically be transferred to GPU.
    Only 1-channel and 3-channel images are supported.

    This transform does not support torchscript.

    Example:
        >>> import torch
        >>> from torchvision.transforms import v2
        >>> img_tensor = torch.randint(0, 256, (1, 3, 320, 240), dtype=torch.uint8)
        >>> img_cvcuda = v2.ToCVCUDATensor()(img_tensor)
        >>> print(img_cvcuda.shape)
        (1, 320, 240, 3)
    """

    def transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> "cvcuda.Tensor":
        # ``params`` is unused; the conversion is delegated to the functional API.
        return F.to_cvcuda_tensor(inpt)


class CVCUDAToTensor(Transform):
    """Convert a `cvcuda.Tensor` with NHWC layout to a `torch.Tensor` with NCHW layout.

    This transform does not support torchscript.

    Example:
        >>> import cvcuda
        >>> import torch
        >>> from torchvision.transforms import v2
        >>> img_tensor = torch.randint(0, 256, (1, 240, 320, 3), dtype=torch.uint8, device="cuda")
        >>> img_cvcuda = cvcuda.as_tensor(img_tensor, cvcuda.TensorLayout.NHWC)
        >>> img_tensor = v2.CVCUDAToTensor()(img_cvcuda)
        >>> print(img_tensor.shape)
        torch.Size([1, 3, 240, 320])
    """

    # CV-CUDA is an optional dependency. When it cannot be imported the class must
    # still be definable, so `_transformed_types` is simply not overridden here.
    try:
        cvcuda = _import_cvcuda()
        _transformed_types = (cvcuda.Tensor,)
    except ImportError:
        pass

    def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor:
        # ``params`` is unused; the conversion is delegated to the functional API.
        return F.cvcuda_to_tensor(inpt)
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,6 @@
to_dtype_video,
)
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image, to_pil_image
from ._type_conversion import cvcuda_to_tensor, pil_to_tensor, to_cvcuda_tensor, to_image, to_pil_image

from ._deprecated import get_image_size, to_tensor # usort: skip
25 changes: 23 additions & 2 deletions torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, TYPE_CHECKING, Union

import PIL.Image
import torch
Expand All @@ -9,7 +9,14 @@

from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor

# Probe for CV-CUDA once at import time; the flag gates the optional
# cvcuda-specific kernel registration in this module.
CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
    import cvcuda  # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
    # At runtime, bind `cvcuda` to the real module (hence the F811 redefinition waiver).
    cvcuda = _import_cvcuda()  # noqa: F811


def get_dimensions(inpt: torch.Tensor) -> list[int]:
Expand Down Expand Up @@ -107,6 +114,20 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]:
return [height, width]


def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
    """Get size of `cvcuda.Tensor` with NHWC layout.

    Args:
        image: Tensor whose trailing dimensions are ``(height, width, channels)``.
            Only its ``shape`` attribute is read.

    Returns:
        ``[height, width]``.

    Raises:
        TypeError: If ``image`` has fewer than three dimensions, so that height and
            width cannot both be read from a channels-last shape.
    """
    shape = tuple(image.shape)
    # Channels come last, so height and width are the two dimensions before the final one.
    if len(shape) < 3:
        raise TypeError(f"Input tensor should have at least three dimensions, but got {len(shape)}")
    return list(shape[-3:-1])


# `cvcuda.Tensor` can only be referenced when CV-CUDA was imported, so the kernel is
# registered conditionally by calling the decorator factory by hand.
if CVCUDA_AVAILABLE:
    _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)


@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
def get_size_video(video: torch.Tensor) -> list[int]:
return get_size_image(video)
Expand Down
39 changes: 38 additions & 1 deletion torchvision/transforms/v2/functional/_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from typing import Union
from typing import TYPE_CHECKING, Union

import numpy as np
import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.transforms import functional as _F
from torchvision.utils import _log_api_usage_once

from ._utils import _import_cvcuda

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]


@torch.jit.unused
Expand All @@ -25,3 +31,34 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso

to_pil_image = _F.to_pil_image
pil_to_tensor = _F.pil_to_tensor


@torch.jit.unused
def to_cvcuda_tensor(inpt: torch.Tensor) -> "cvcuda.Tensor":
    """See :class:`~torchvision.transforms.v2.ToCVCUDATensor` for details.

    Args:
        inpt: Image tensor with 3 (``CHW``) or 4 (``NCHW``) dimensions. A 3-dimensional
            input is treated as a single image and gets a leading batch dimension of size 1.

    Returns:
        A ``cvcuda.Tensor`` created with ``cvcuda.TensorLayout.NHWC``.

    Raises:
        ImportError: If CV-CUDA is not installed (raised by ``_import_cvcuda``).
        TypeError: If ``inpt`` is not a ``torch.Tensor``.
        ValueError: If ``inpt`` is neither 3- nor 4-dimensional.
    """
    cvcuda = _import_cvcuda()
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_cvcuda_tensor)
    if not isinstance(inpt, torch.Tensor):
        raise TypeError(f"inpt should be `torch.Tensor`. Got {type(inpt)}.")
    if inpt.ndim not in [3, 4]:
        raise ValueError(f"pic should be 3 or 4 dimensional. Got {inpt.ndim} dimensions.")
    # The 4-index permute below needs a batch dimension. Without this promotion a
    # 3-dimensional input passed the check above and then failed inside `permute`.
    # NOTE(review): a reviewer suggested rejecting non-batched input instead
    # (`inpt.ndim == 4`); if that is preferred, tighten the check and its message together.
    if inpt.ndim == 3:
        inpt = inpt.unsqueeze(0)
    # Convert to NHWC as CVCUDA transforms do not support NCHW
    inpt = inpt.permute(0, 2, 3, 1)
    return cvcuda.as_tensor(inpt.cuda().contiguous(), cvcuda.TensorLayout.NHWC)


@torch.jit.unused
def cvcuda_to_tensor(cvcuda_img: "cvcuda.Tensor") -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.CVCUDAToTensor` for details.

    Args:
        cvcuda_img: CV-CUDA tensor with 3 or 4 dimensions. A 3-dimensional input is
            treated as a single image and gets a leading batch dimension of size 1.

    Returns:
        A CUDA ``torch.Tensor`` with the channel dimension moved from last to second.

    Raises:
        ImportError: If CV-CUDA is not installed (raised by ``_import_cvcuda``).
        TypeError: If ``cvcuda_img`` is not a ``cvcuda.Tensor``.
        ValueError: If ``cvcuda_img`` is neither 3- nor 4-dimensional.
    """
    cvcuda = _import_cvcuda()
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(cvcuda_to_tensor)
    if not isinstance(cvcuda_img, cvcuda.Tensor):
        raise TypeError(f"cvcuda_img should be `cvcuda.Tensor`. Got {type(cvcuda_img)}.")
    # Validate the rank before doing any conversion work.
    if cvcuda_img.ndim not in [3, 4]:
        raise ValueError(f"Image should be 3 or 4 dimensional. Got {cvcuda_img.ndim} dimensions.")
    cuda_tensor = torch.as_tensor(cvcuda_img.cuda(), device="cuda")
    # The 4-index permute below needs a batch dimension. Without this promotion a
    # 3-dimensional input passed the check above and then failed inside `permute`.
    # NOTE(review): assumes a 3-dimensional tensor is HWC — confirm against callers.
    if cuda_tensor.ndim == 3:
        cuda_tensor = cuda_tensor.unsqueeze(0)
    # Convert to NCHW layout from CVCUDA default NHWC
    img = cuda_tensor.permute(0, 3, 1, 2)
    return img
29 changes: 29 additions & 0 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,32 @@ def decorator(kernel):
return kernel

return decorator


def _import_cvcuda():
    """Import CV-CUDA with an informative error message if it is not installed.

    Returns:
        The imported ``cvcuda`` module.

    Raises:
        ImportError: If CV-CUDA is not installed. The original import failure is
            chained as the cause.
    """
    try:
        import cvcuda  # type: ignore[import-not-found]
    except ImportError as e:
        raise ImportError(
            "CV-CUDA is required but not installed. "
            "Please install it following the instructions at "
            "https://github.com/CVCUDA/CV-CUDA."
        ) from e
    return cvcuda


def _is_cvcuda_available():
    """Return ``True`` if ``_import_cvcuda`` succeeds, ``False`` if it raises ``ImportError``."""
    try:
        _import_cvcuda()
    except ImportError:
        return False
    return True