@@ -102,8 +102,11 @@ def pil_pipeline(image: Image.Image, target_size: int) -> torch.Tensor:
102102
103103
104104def cudacv_pipeline (image : torch .Tensor , target_size : int ) -> torch .Tensor :
105- # Permute from NCHW to NHWC and ensure contiguity
106- image = image .permute (0 , 2 , 3 , 1 ).contiguous ()
105+ channel_first = image .shape [- 1 ] != 3
106+ if channel_first :
107+ image = image .permute (0 , 2 , 3 , 1 ).contiguous ()
108+ # image = cvcuda.as_tensor(image, nvcv.TensorLayout.NCHW)
109+ # image = cvcuda.reformat(image, nvcv.TensorLayout.NHWC)
107110 image = cvcuda .as_tensor (image , nvcv .TensorLayout .NHWC )
108111 img : cvcuda .Tensor = cvcuda .resize (
109112 image ,
@@ -119,7 +122,10 @@ def cudacv_pipeline(image: torch.Tensor, target_size: int) -> torch.Tensor:
119122 img : cvcuda .Tensor = cvcuda .normalize (
120123 img , NORM_MEAN_CUDA_CV , NORM_STD_CUDA_CV
121124 )
122- return torch .as_tensor (img .cuda ())
125+ out = torch .as_tensor (img .cuda ())
126+ if channel_first :
127+ out = out .permute (0 , 3 , 1 , 2 )
128+ return out
123129
124130
125131def albumentations_pipeline (image : np .ndarray , target_size : int ) -> torch .Tensor :
0 commit comments