Skip to content

Commit 1d05f58

Browse files
fix 3 dimensions check
1 parent ca7b9a1 commit 1d05f58

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

test/test_transforms_v2.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6768,17 +6768,31 @@ def test_invalid_input_type(self):
67686768
F.to_cvcuda_tensor("invalid_input")
67696769

67706770
def test_invalid_dimensions(self):
6771-
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
6771+
with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
6772+
img_data = torch.randint(
6773+
0,
6774+
256,
6775+
(
6776+
3,
6777+
12,
6778+
34,
6779+
),
6780+
dtype=torch.uint8,
6781+
)
6782+
img_data = img_data.cuda()
6783+
F.to_cvcuda_tensor(img_data)
6784+
6785+
with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
67726786
img_data = torch.randint(0, 256, (4,), dtype=torch.uint8)
67736787
img_data = img_data.cuda()
67746788
F.to_cvcuda_tensor(img_data)
67756789

6776-
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
6790+
with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
67776791
img_data = torch.randint(0, 256, (4, 4), dtype=torch.uint8)
67786792
img_data = img_data.cuda()
67796793
F.to_cvcuda_tensor(img_data)
67806794

6781-
with pytest.raises(ValueError, match=r"pic should be 3 or 4 dimensional"):
6795+
with pytest.raises(ValueError, match=r"pic should be 4 dimensional"):
67826796
img_data = torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8)
67836797
img_data = img_data.cuda()
67846798
F.to_cvcuda_tensor(img_data)

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def to_cvcuda_tensor(inpt: torch.Tensor) -> "cvcuda.Tensor":
4141
_log_api_usage_once(to_cvcuda_tensor)
4242
if not isinstance(inpt, torch.Tensor):
4343
raise TypeError(f"inpt should be `torch.Tensor`. Got {type(inpt)}.")
44-
if inpt.ndim not in [3, 4]:
45-
raise ValueError(f"pic should be 3 or 4 dimensional. Got {inpt.ndim} dimensions.")
44+
if inpt.ndim != 4:
45+
raise ValueError(f"pic should be 4 dimensional. Got {inpt.ndim} dimensions.")
4646
# Convert to NHWC as CVCUDA transforms do not support NCHW
4747
inpt = inpt.permute(0, 2, 3, 1)
4848
return cvcuda.as_tensor(inpt.cuda().contiguous(), cvcuda.TensorLayout.NHWC)
@@ -57,8 +57,8 @@ def cvcuda_to_tensor(cvcuda_img: "cvcuda.Tensor") -> torch.Tensor:
5757
if not isinstance(cvcuda_img, cvcuda.Tensor):
5858
raise TypeError(f"cvcuda_img should be `cvcuda.Tensor`. Got {type(cvcuda_img)}.")
5959
cuda_tensor = torch.as_tensor(cvcuda_img.cuda(), device="cuda")
60-
if cvcuda_img.ndim not in [3, 4]:
61-
raise ValueError(f"Image should be 3 or 4 dimensional. Got {cuda_tensor.ndim} dimensions.")
60+
if cvcuda_img.ndim != 4:
61+
raise ValueError(f"Image should be 4 dimensional. Got {cuda_tensor.ndim} dimensions.")
6262
# Convert to NCHW layout from CVCUDA default NHWC
6363
img = cuda_tensor.permute(0, 3, 1, 2)
6464
return img

0 commit comments

Comments
 (0)