Introducing CVCUDA Backend #9259
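The diff below adds `F.to_cvcuda_tensor` and `F.cvcuda_to_tensor` conversions (plus a `CVCUDAToTensor` transform exercised in the tests). A minimal usage sketch based on those functions, assuming CV-CUDA and a CUDA device are available (not part of the diff itself):

```python
import torch
from torchvision.transforms.v2 import functional as F

# CHW uint8 image on the GPU (CV-CUDA tensors live on the CUDA device).
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8, device="cuda")

# torch CHW/NCHW -> cvcuda.Tensor in NHWC layout.
cvcuda_img = F.to_cvcuda_tensor(img)

# cvcuda.Tensor -> torch NCHW (always batched on the way back).
out = F.cvcuda_to_tensor(cvcuda_img)  # shape (1, 3, 224, 224)
out = out.squeeze(0)                  # drop the batch dim added during conversion
```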
test/test_transforms_v2.py

@@ -29,6 +29,7 @@
    make_bounding_boxes,
    make_detection_masks,
    make_image,
    make_image_cvcuda,
    make_image_pil,
    make_image_tensor,
    make_keypoints,

@@ -51,8 +52,16 @@
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
from torchvision.transforms.v2.functional._type_conversion import _import_cvcuda_modules
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal

try:
    _import_cvcuda_modules()
    CVCUDA_AVAILABLE = True
except ImportError:
    CVCUDA_AVAILABLE = False
CUDA_AVAILABLE = torch.cuda.is_available()


# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]

@@ -6733,6 +6742,125 @@ def test_functional_error(self):
            F.pil_to_tensor(object())


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
class TestToCVCUDATensor:
    """Tests for to_cvcuda_tensor function following patterns from TestToPil"""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
    def test_1_channel_to_cvcuda_tensor(self, dtype):
        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
        if dtype in (torch.uint8, torch.uint16):
            img_data = torch.randint(0, 256, (1, 4, 4), dtype=dtype)
        else:
            img_data = torch.rand(1, 4, 4, dtype=dtype)
        img_data = img_data.cuda()
        cvcuda_img = F.to_cvcuda_tensor(img_data)
        assert cvcuda_img is not None

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
    def test_3_channel_to_cvcuda_tensor(self, dtype):
        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
        if dtype in (torch.uint8, torch.uint16):
            img_data = torch.randint(0, 256, (3, 4, 4), dtype=dtype)
        else:
            img_data = torch.rand(3, 4, 4, dtype=dtype)
        img_data = img_data.cuda()
        cvcuda_img = F.to_cvcuda_tensor(img_data)
        assert cvcuda_img is not None

    def test_invalid_input_type(self):
        with pytest.raises(TypeError, match=r"pic should be `torch.Tensor`"):
            F.to_cvcuda_tensor("invalid_input")

    def test_invalid_dimensions(self):
        # Test 1D array (too few dimensions)
        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
        with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
            img_data = torch.randint(0, 256, (4,), dtype=torch.uint8)
            img_data = img_data.cuda()
            F.to_cvcuda_tensor(img_data)

        # Test 2D array (no longer supported)
        with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
            img_data = torch.randint(0, 256, (4, 4), dtype=torch.uint8)
            img_data = img_data.cuda()
            F.to_cvcuda_tensor(img_data)

        # Test 5D array (too many dimensions)
        with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
            img_data = torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8)
            img_data = img_data.cuda()
            F.to_cvcuda_tensor(img_data)

    @pytest.mark.parametrize("num_channels", [1, 3])
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
    def test_round_trip(self, num_channels, dtype):
        # Setup: Create a tensor in CHW format (PyTorch standard)
        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
        if dtype in (torch.uint8, torch.uint16):
            original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype)
        else:
            original_tensor = torch.rand(num_channels, 4, 4, dtype=dtype)
        original_tensor = original_tensor.cuda()

        # Execute: Convert to CV-CUDA and back to tensor
        # CHW -> (to_cvcuda_tensor) -> CV-CUDA NHWC -> (cvcuda_to_tensor) -> NCHW
        cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
        result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)

        # Remove batch dimension that was added during conversion since original was unbatched
        result_tensor = result_tensor.squeeze(0)

        # Assert: The round-trip conversion preserves the original tensor exactly
        torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)

    @pytest.mark.parametrize("num_channels", [1, 3])
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
    @pytest.mark.parametrize("batch_size", [1, 2, 4])
    def test_round_trip_batched(self, num_channels, dtype, batch_size):
        # Setup: Create a batched tensor in NCHW format
        # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
        if dtype in (torch.uint8, torch.uint16):
            original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype)
        else:
            original_tensor = torch.rand(batch_size, num_channels, 4, 4, dtype=dtype)
        original_tensor = original_tensor.cuda()

        # Execute: Convert to CV-CUDA and back to tensor
        # NCHW -> (to_cvcuda_tensor) -> CV-CUDA NHWC -> (cvcuda_to_tensor) -> NCHW
        cvcuda_tensor = F.to_cvcuda_tensor(original_tensor)
        result_tensor = F.cvcuda_to_tensor(cvcuda_tensor)

        # Assert: The round-trip conversion preserves the original batched tensor exactly
        torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)
        # Also verify batch size is preserved
        assert result_tensor.shape[0] == batch_size


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
class TestCVCUDAToTensor:
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize(
        "fn",
        [F.cvcuda_to_tensor, transform_cls_to_functional(transforms.CVCUDAToTensor)],
    )
    def test_functional_and_transform(self, color_space, fn):
        input = make_image_cvcuda(color_space=color_space)

        output = fn(input)

        assert isinstance(output, torch.Tensor)
        # Convert input to tensor to compare sizes
        input_tensor = F.cvcuda_to_tensor(input)
        assert F.get_size(output) == F.get_size(input_tensor)

    def test_functional_error(self):
        with pytest.raises(TypeError, match="cvcuda_img should be `cvcuda.Tensor`"):
            F.cvcuda_to_tensor(object())


class TestLambda:
    @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
    @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
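The tests above also cover the transform-class counterpart, `transforms.CVCUDAToTensor`. A short sketch of calling it directly, assuming it behaves like the functional `F.cvcuda_to_tensor` (as `transform_cls_to_functional` in the tests implies):

```python
import torch
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F

# Build a cvcuda.Tensor from a GPU image, then convert back with the transform class.
img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda")
cvcuda_img = F.to_cvcuda_tensor(img)

to_tensor = transforms.CVCUDAToTensor()
out = to_tensor(cvcuda_img)  # torch.Tensor, NCHW, shape (1, 3, 4, 4)
```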
torchvision/transforms/v2/functional/_type_conversion.py

@@ -1,10 +1,36 @@
-from typing import Union
+from typing import TYPE_CHECKING, Union

import numpy as np
import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.transforms import functional as _F
from torchvision.utils import _log_api_usage_once

if TYPE_CHECKING:
    import cvcuda  # type: ignore[import-not-found]


def _import_cvcuda_modules():
    """Import CV-CUDA modules with an informative error message if not installed.

    Returns:
        cvcuda module.

    Raises:
        ImportError: If CV-CUDA is not installed.
    """
    try:
        import cvcuda  # type: ignore[import-not-found]

        return cvcuda
    except ImportError as e:
        raise ImportError(
            "CV-CUDA is required but not installed. "
            "Please install it following the instructions at "
            "https://github.com/CVCUDA/CV-CUDA or via pip: "
            "`pip install cvcuda-cu12` (for CUDA 12) or `pip install cvcuda-cu11` (for CUDA 11)."
        ) from e


@torch.jit.unused

@@ -25,3 +51,82 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso

to_pil_image = _F.to_pil_image
pil_to_tensor = _F.pil_to_tensor


@torch.jit.unused
def to_cvcuda_tensor(pic) -> "cvcuda.Tensor":
    """Convert a torch.Tensor to cvcuda.Tensor. This function does not support torchscript.

    See :class:`~torchvision.transforms.v2.ToCVCUDATensor` for more details.

    Args:
        pic (torch.Tensor): Image to be converted to cvcuda.Tensor.
            Tensor can be in CHW format (unbatched) or NCHW format (batched).
            Only 1-channel and 3-channel images are supported.

    Returns:
        cvcuda.Tensor: Image converted to cvcuda.Tensor with NHWC layout.
    """
    cvcuda = _import_cvcuda_modules()

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_cvcuda_tensor)

    # Validate input type
    if not isinstance(pic, torch.Tensor):
        raise TypeError(f"pic should be `torch.Tensor`. Got {type(pic)}.")

    # Validate dimensions - only support 3D (CHW) or 4D (NCHW)
    if pic.ndim == 3:
        # Add fake batch dimension to make it 4D
        img_tensor = pic.unsqueeze(0)
    elif pic.ndim == 4:
        img_tensor = pic
    else:
        raise ValueError(f"pic should be 3 or 4 dimensional. Got {pic.ndim} dimensions.")

    # Convert NCHW -> NHWC
    img_tensor = img_tensor.permute(0, 2, 3, 1)

    # Convert to CV-CUDA tensor with NHWC layout
    return cvcuda.as_tensor(img_tensor.cuda().contiguous(), cvcuda.TensorLayout.NHWC)


@torch.jit.unused
def cvcuda_to_tensor(cvcuda_img: "cvcuda.Tensor") -> torch.Tensor:
    """Convert a cvcuda.Tensor to a PyTorch tensor. This function does not support torchscript.

    Args:
        cvcuda_img (cvcuda.Tensor): cvcuda.Tensor to be converted to PyTorch tensor.
            Expected to be in NHWC or NHW layout (batched images only).

    Returns:
        torch.Tensor: Converted image in NCHW format (batched).
    """
    cvcuda = _import_cvcuda_modules()

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(cvcuda_to_tensor)

    # Validate input type
    if not isinstance(cvcuda_img, cvcuda.Tensor):
        raise TypeError(f"cvcuda_img should be `cvcuda.Tensor`. Got {type(cvcuda_img)}.")

    # Convert CV-CUDA Tensor to PyTorch tensor via CUDA array interface
    # CV-CUDA tensors expose __cuda_array_interface__ which PyTorch can consume directly
    cuda_tensor = torch.as_tensor(cvcuda_img.cuda(), device="cuda")

    # Only support 4D (NHWC) or 3D (NHW) batched tensors
    # CV-CUDA stores images in NHWC (batched multi-channel) or NHW (batched single-channel) format
    if cuda_tensor.ndim == 4:
        # Batched multi-channel image in NHWC format
        # Convert NHWC -> NCHW
        img = cuda_tensor.permute(0, 3, 1, 2)
    elif cuda_tensor.ndim == 3:
        # Batched single-channel image in NHW format
        # Convert NHW -> NCHW by adding channel dimension
        img = cuda_tensor.unsqueeze(1)
    else:
        raise ValueError(f"Image should be 3 or 4 dimensional. Got {cuda_tensor.ndim} dimensions.")

    return img
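To make the layout handling of the two conversions above concrete, a small shape-focused sketch with hypothetical sizes (assumes CV-CUDA and a CUDA device; mirrors the round-trip tests earlier in the diff):

```python
import torch
from torchvision.transforms.v2 import functional as F

chw = torch.rand(3, 8, 8, device="cuda")      # unbatched CHW input
cv = F.to_cvcuda_tensor(chw)                  # cvcuda.Tensor, NHWC, shape (1, 8, 8, 3)
back = F.cvcuda_to_tensor(cv)                 # torch.Tensor, NCHW, shape (1, 3, 8, 8)
assert back.squeeze(0).shape == chw.shape     # caller drops the batch dim added on the way in

nchw = torch.rand(2, 1, 8, 8, device="cuda")  # batched single-channel input
cv_b = F.to_cvcuda_tensor(nchw)               # NHWC, shape (2, 8, 8, 1)
back_b = F.cvcuda_to_tensor(cv_b)             # NCHW, shape (2, 1, 8, 8)
```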