Skip to content

Commit 923b2be

Browse files
Do tensor conversion in benchmark function
1 parent c37ed06 commit 923b2be

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

benchmarks/benchmark_transforms.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,13 @@ def pil_pipeline(image: Image.Image, target_size: int) -> torch.Tensor:
101101
return img
102102

103103

104-
def cudacv_pipeline(image: cvcuda.Tensor, target_size: int) -> torch.Tensor:
105-
104+
def cudacv_pipeline(image: torch.Tensor, target_size: int) -> torch.Tensor:
105+
# Permute from NCHW to NHWC and ensure contiguity
106+
image = image.permute(0, 2, 3, 1).contiguous()
107+
image = cvcuda.as_tensor(image, nvcv.TensorLayout.NHWC)
106108
img: cvcuda.Tensor = cvcuda.resize(
107109
image,
108-
(image.shape[0], target_size, target_size, image.shape[-1]), # N, H, W, C
110+
(image.shape[0], target_size, target_size, image.shape[-1]),
109111
interp=cvcuda.Interp.LINEAR,
110112
)
111113
img: cvcuda.Tensor = cvcuda.convertto(
@@ -234,9 +236,6 @@ def generate_test_images():
234236
elif backend == "cudacv":
235237
if images.ndim == 3: # no batch dimension
236238
images = images.unsqueeze(0)
237-
# Permute from NCHW to NHWC and ensure contiguity
238-
images = images.permute(0, 2, 3, 1).contiguous()
239-
images = cvcuda.as_tensor(images, nvcv.TensorLayout.NHWC)
240239
elif backend == "albumentations":
241240
if args.batch_size > 1:
242241
# TODO is that true????

0 commit comments

Comments
 (0)