Skip to content

Commit f0cd706

Browse files
Lint
1 parent 4cff7b1 commit f0cd706

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

test/test_cvcuda.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
from torchvision import _is_cvcuda_available
2-
import torch
31
import pytest
42
import torch
3+
from torchvision import _is_cvcuda_available
54
from torchvision.transforms.v2 import functional as F
5+
66
CVCUDA_AVAILABLE = _is_cvcuda_available()
77
CUDA_AVAILABLE = torch.cuda.is_available()
88

99

1010
if CVCUDA_AVAILABLE:
1111
import nvcv
1212

13+
1314
@pytest.mark.skipif(CVCUDA_AVAILABLE is False, reason="test requires CVCUDA")
1415
@pytest.mark.skipif(CUDA_AVAILABLE is False, reason="test requires CUDA")
1516
class TestToNvcvTensor:
@@ -156,16 +157,24 @@ def transform_cls_to_functional(get_transform_cls):
156157
def wrapper(inpt):
157158
transform_cls = get_transform_cls()
158159
return transform_cls()(inpt)
160+
159161
return wrapper
160162

161163

162164
@pytest.mark.skipif(CVCUDA_AVAILABLE is False, reason="test requires CVCUDA")
163165
@pytest.mark.skipif(CUDA_AVAILABLE is False, reason="test requires CUDA")
164166
class TestNVCVToTensor:
165-
166167
@pytest.mark.parametrize("num_channels", [1, 3, 4])
167168
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
168-
@pytest.mark.parametrize("fn", [F.nvcv_to_tensor, transform_cls_to_functional(lambda: __import__('torchvision.transforms.v2', fromlist=['NVCVToTensor']).NVCVToTensor)])
169+
@pytest.mark.parametrize(
170+
"fn",
171+
[
172+
F.nvcv_to_tensor,
173+
transform_cls_to_functional(
174+
lambda: __import__("torchvision.transforms.v2", fromlist=["NVCVToTensor"]).NVCVToTensor
175+
),
176+
],
177+
)
169178
def test_functional_and_transform(self, num_channels, dtype, fn):
170179
input = make_nvcv_image(num_channels=num_channels, dtype=dtype)
171180
output = fn(input)

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,4 @@
6262
from torchvision import _is_cvcuda_available
6363

6464
if _is_cvcuda_available():
65-
from ._cvcuda import ToNVCVTensor, NVCVToTensor
65+
from ._cvcuda import NVCVToTensor, ToNVCVTensor

torchvision/transforms/v2/_cvcuda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from torchvision.utils import _log_api_usage_once
21
from torchvision.transforms.v2 import functional as F
2+
from torchvision.utils import _log_api_usage_once
33

44

55
class ToNVCVTensor:

torchvision/transforms/v2/functional/_cvcuda.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import numpy as np
21
import warnings
2+
3+
import numpy as np
4+
import nvcv
35
import torch
46
from torchvision.utils import _log_api_usage_once
5-
import nvcv
67

78

89
def _infer_nvcv_format(img_tensor: torch.Tensor):
@@ -100,8 +101,7 @@ def _validate_nvcv_format(format, num_channels: int) -> None:
100101
elif num_channels == 2:
101102
if format not in two_channel_formats:
102103
warnings.warn(
103-
f"Format {format} may not be appropriate for 2-channel image. "
104-
f"Common 2-channel formats: _2F32"
104+
f"Format {format} may not be appropriate for 2-channel image. " f"Common 2-channel formats: _2F32"
105105
)
106106
elif num_channels == 3:
107107
if format not in three_channel_formats:
@@ -159,7 +159,7 @@ def to_nvcv_tensor(pic, format=None):
159159
# Ensure image has channel dimension for unbatched case
160160
if img_tensor.ndim == 2:
161161
img_tensor = img_tensor.unsqueeze(2) # H W -> H W C
162-
162+
163163
# Validate dimensions
164164
if img_tensor.ndim not in (3, 4):
165165
raise ValueError(f"pic should be 2/3/4 dimensional. Got {img_tensor.ndim} dimensions.")
@@ -201,6 +201,7 @@ def to_nvcv_tensor(pic, format=None):
201201

202202
# Convert to NVCV tensor with NHWC layout
203203
import cvcuda
204+
204205
return cvcuda.as_tensor(img_tensor.contiguous(), nvcv.TensorLayout.NHWC)
205206

206207

@@ -222,7 +223,7 @@ def nvcv_to_tensor(nvcv_img):
222223

223224
# Convert NVCV Tensor to PyTorch tensor via CUDA array interface
224225
# NVCV tensors expose __cuda_array_interface__ which PyTorch can consume directly
225-
cuda_tensor = torch.as_tensor(nvcv_img.cuda(), device='cuda')
226+
cuda_tensor = torch.as_tensor(nvcv_img.cuda(), device="cuda")
226227

227228
# Handle different dimensionalities
228229
if cuda_tensor.ndim == 4:

0 commit comments

Comments
 (0)