Skip to content

Commit 32a5bf9

Browse files
fix return for _validate_cvcuda_dtype
1 parent 1dcae5c commit 32a5bf9

File tree

1 file changed

+8
-27
lines changed

1 file changed

+8
-27
lines changed

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -53,51 +53,32 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso
5353
pil_to_tensor = _F.pil_to_tensor
5454

5555

56-
def _infer_cvcuda_format(img_tensor: torch.Tensor):
57-
"""Infer CV-CUDA format from tensor shape and dtype.
56+
def _validate_cvcuda_dtype(img_tensor: torch.Tensor) -> None:
57+
"""Validate that tensor dtype and channel count are supported by CV-CUDA.
5858
5959
Args:
6060
img_tensor: Tensor with shape (H, W, C) where C is number of channels.
6161
62-
Returns:
63-
tuple: (cvcuda_format, processed_tensor) where processed_tensor may have reduced dimensions
64-
for single channel images.
65-
6662
Raises:
6763
TypeError: If dtype is not supported for the given number of channels.
6864
ValueError: If number of channels is not 1 or 3.
6965
"""
70-
cvcuda = _import_cvcuda_modules()
71-
7266
num_channels = img_tensor.shape[2]
7367
dtype = img_tensor.dtype
7468

7569
# Handle single channel images
7670
if num_channels == 1:
77-
if dtype == torch.uint8:
78-
return cvcuda.Format.U8, img_tensor
79-
elif dtype == torch.int16:
80-
return cvcuda.Format.S16, img_tensor
81-
elif dtype == torch.int32:
82-
return cvcuda.Format.S32, img_tensor
83-
elif dtype == torch.float32:
84-
return cvcuda.Format.F32, img_tensor
85-
elif dtype == torch.float64:
86-
return cvcuda.Format.F64, img_tensor
87-
else:
71+
if dtype not in (torch.uint8, torch.int16, torch.int32, torch.float32, torch.float64):
8872
raise TypeError(f"Unsupported dtype {dtype} for single channel image")
8973

9074
# Handle 3 channel images (defaults to RGB)
9175
elif num_channels == 3:
92-
if dtype == torch.uint8:
93-
return cvcuda.Format.RGB8, img_tensor
94-
elif dtype == torch.float32:
95-
return cvcuda.Format.RGBf32, img_tensor
96-
else:
76+
if dtype not in (torch.uint8, torch.float32):
9777
# Note: CV-CUDA does not support float64 for RGB images (only F64 for single-channel)
9878
raise TypeError(f"Unsupported dtype {dtype} for 3-channel image")
9979

100-
raise ValueError(f"Only 1 and 3 channel images are supported. Got {num_channels} channels.")
80+
else:
81+
raise ValueError(f"Only 1 and 3 channel images are supported. Got {num_channels} channels.")
10182

10283

10384
@torch.jit.unused
@@ -136,9 +117,9 @@ def to_cvcuda_tensor(pic) -> "cvcuda.Tensor":
136117
# Convert NCHW -> NHWC
137118
img_tensor = img_tensor.permute(0, 2, 3, 1)
138119

139-
# Infer format from the first image
120+
# Validate dtype and channel count from the first image
140121
sample_img = img_tensor[0]
141-
_infer_cvcuda_format(sample_img)
122+
_validate_cvcuda_dtype(sample_img)
142123

143124
# Convert to CV-CUDA tensor with NHWC layout
144125
return cvcuda.as_tensor(img_tensor.cuda().contiguous(), cvcuda.TensorLayout.NHWC)

0 commit comments

Comments
 (0)