Skip to content

Commit 00d4484

Browse files
add tv_tensors.Image in tests
1 parent 358cf81 commit 00d4484

File tree

2 files changed

+12
-17
lines changed

2 files changed

+12
-17
lines changed

test/test_transforms_v2.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6746,6 +6746,7 @@ def test_functional_error(self):
67466746
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
67476747
@needs_cuda
67486748
class TestToCVCUDATensor:
6749+
@pytest.mark.parametrize("image_type", (torch.Tensor, tv_tensors.Image))
67496750
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
67506751
@pytest.mark.parametrize("device", cpu_and_cuda())
67516752
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@@ -6754,30 +6755,24 @@ class TestToCVCUDATensor:
67546755
"fn",
67556756
[F.to_cvcuda_tensor, transform_cls_to_functional(transforms.ToCVCUDATensor)],
67566757
)
6757-
def test_functional_and_transform(self, dtype, device, color_space, batch_dims, fn):
6758-
input = make_image_tensor(dtype=dtype, device=device, color_space=color_space, batch_dims=batch_dims)
6759-
output = fn(input)
6758+
def test_functional_and_transform(self, image_type, dtype, device, color_space, batch_dims, fn):
6759+
image = make_image(dtype=dtype, device=device, color_space=color_space, batch_dims=batch_dims)
6760+
if image_type is torch.Tensor:
6761+
image = image.as_subclass(torch.Tensor)
6762+
assert is_pure_tensor(image)
6763+
output = fn(image)
67606764

67616765
assert isinstance(output, cvcuda.Tensor)
6762-
assert F.get_size(output) == F.get_size(input)
6766+
assert F.get_size(output) == F.get_size(image)
67636767
assert output is not None
67646768

67656769
def test_invalid_input_type(self):
6766-
with pytest.raises(TypeError, match=r"inpt should be `torch.Tensor`"):
6770+
with pytest.raises(TypeError, match=r"inpt should be ``torch.Tensor``"):
67676771
F.to_cvcuda_tensor("invalid_input")
67686772

67696773
def test_invalid_dimensions(self):
67706774
with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
6771-
img_data = torch.randint(
6772-
0,
6773-
256,
6774-
(
6775-
3,
6776-
12,
6777-
34,
6778-
),
6779-
dtype=torch.uint8,
6780-
)
6775+
img_data = torch.randint(0, 256, (3, 1, 3), dtype=torch.uint8)
67816776
img_data = img_data.cuda()
67826777
F.to_cvcuda_tensor(img_data)
67836778

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def to_cvcuda_tensor(inpt: torch.Tensor) -> "cvcuda.Tensor":
3939
cvcuda = _import_cvcuda()
4040
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
4141
_log_api_usage_once(to_cvcuda_tensor)
42-
if not isinstance(inpt, torch.Tensor):
43-
raise TypeError(f"inpt should be ``torch.Tensor``. Got {type(inpt)}.")
42+
if not isinstance(inpt, (torch.Tensor, tv_tensors.Image)):
43+
raise TypeError(f"inpt should be ``torch.Tensor`` or ``tv_tensors.Image``. Got {type(inpt)}.")
4444
if inpt.ndim != 4:
4545
raise ValueError(f"pic should be 4 dimensional. Got {inpt.ndim} dimensions.")
4646
# Convert to NHWC as CVCUDA transforms do not support NCHW

0 commit comments

Comments
 (0)