@@ -53,51 +53,32 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso
5353pil_to_tensor = _F .pil_to_tensor
5454
5555
56- def _infer_cvcuda_format (img_tensor : torch .Tensor ):
57- """Infer CV-CUDA format from tensor shape and dtype .
56+ def _validate_cvcuda_dtype (img_tensor : torch .Tensor ) -> None :
57+ """Validate that tensor dtype and channel count are supported by CV-CUDA .
5858
5959 Args:
6060 img_tensor: Tensor with shape (H, W, C) where C is number of channels.
6161
62- Returns:
63- tuple: (cvcuda_format, processed_tensor) where processed_tensor may have reduced dimensions
64- for single channel images.
65-
6662 Raises:
6763 TypeError: If dtype is not supported for the given number of channels.
6864 ValueError: If number of channels is not 1 or 3.
6965 """
70- cvcuda = _import_cvcuda_modules ()
71-
7266 num_channels = img_tensor .shape [2 ]
7367 dtype = img_tensor .dtype
7468
7569 # Handle single channel images
7670 if num_channels == 1 :
77- if dtype == torch .uint8 :
78- return cvcuda .Format .U8 , img_tensor
79- elif dtype == torch .int16 :
80- return cvcuda .Format .S16 , img_tensor
81- elif dtype == torch .int32 :
82- return cvcuda .Format .S32 , img_tensor
83- elif dtype == torch .float32 :
84- return cvcuda .Format .F32 , img_tensor
85- elif dtype == torch .float64 :
86- return cvcuda .Format .F64 , img_tensor
87- else :
71+ if dtype not in (torch .uint8 , torch .int16 , torch .int32 , torch .float32 , torch .float64 ):
8872 raise TypeError (f"Unsupported dtype { dtype } for single channel image" )
8973
9074 # Handle 3 channel images (defaults to RGB)
9175 elif num_channels == 3 :
92- if dtype == torch .uint8 :
93- return cvcuda .Format .RGB8 , img_tensor
94- elif dtype == torch .float32 :
95- return cvcuda .Format .RGBf32 , img_tensor
96- else :
76+ if dtype not in (torch .uint8 , torch .float32 ):
9777 # Note: CV-CUDA does not support float64 for RGB images (only F64 for single-channel)
9878 raise TypeError (f"Unsupported dtype { dtype } for 3-channel image" )
9979
100- raise ValueError (f"Only 1 and 3 channel images are supported. Got { num_channels } channels." )
80+ else :
81+ raise ValueError (f"Only 1 and 3 channel images are supported. Got { num_channels } channels." )
10182
10283
10384@torch .jit .unused
@@ -136,9 +117,9 @@ def to_cvcuda_tensor(pic) -> "cvcuda.Tensor":
136117 # Convert NCHW -> NHWC
137118 img_tensor = img_tensor .permute (0 , 2 , 3 , 1 )
138119
139- # Infer format from the first image
120+ # Validate dtype and channel count from the first image
140121 sample_img = img_tensor [0 ]
141- _infer_cvcuda_format (sample_img )
122+ _validate_cvcuda_dtype (sample_img )
142123
143124 # Convert to CV-CUDA tensor with NHWC layout
144125 return cvcuda .as_tensor (img_tensor .cuda ().contiguous (), cvcuda .TensorLayout .NHWC )
0 commit comments