@@ -101,11 +101,13 @@ def pil_pipeline(image: Image.Image, target_size: int) -> torch.Tensor:
101101 return img
102102
103103
104- def cudacv_pipeline (image : cvcuda .Tensor , target_size : int ) -> torch .Tensor :
105-
104+ def cudacv_pipeline (image : torch .Tensor , target_size : int ) -> torch .Tensor :
105+ # Permute from NCHW to NHWC and ensure contiguity
106+ image = image .permute (0 , 2 , 3 , 1 ).contiguous ()
107+ image = cvcuda .as_tensor (image , nvcv .TensorLayout .NHWC )
106108 img : cvcuda .Tensor = cvcuda .resize (
107109 image ,
108- (image .shape [0 ], target_size , target_size , image .shape [- 1 ]), # N, H, W, C
110+ (image .shape [0 ], target_size , target_size , image .shape [- 1 ]),
109111 interp = cvcuda .Interp .LINEAR ,
110112 )
111113 img : cvcuda .Tensor = cvcuda .convertto (
@@ -234,9 +236,6 @@ def generate_test_images():
234236 elif backend == "cudacv" :
235237 if images .ndim == 3 : # no batch dimension
236238 images = images .unsqueeze (0 )
237- # Permute from NCHW to NHWC and ensure contiguity
238- images = images .permute (0 , 2 , 3 , 1 ).contiguous ()
239- images = cvcuda .as_tensor (images , nvcv .TensorLayout .NHWC )
240239 elif backend == "albumentations" :
241240 if args .batch_size > 1 :
242241 # TODO is that true????
0 commit comments