Skip to content

Commit 9f1655a

Browse files
Add logic for channel first and last in cuda-cv
1 parent 923b2be commit 9f1655a

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

benchmarks/benchmark_transforms.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,11 @@ def pil_pipeline(image: Image.Image, target_size: int) -> torch.Tensor:
102102

103103

104104
def cudacv_pipeline(image: torch.Tensor, target_size: int) -> torch.Tensor:
105-
# Permute from NCHW to NHWC and ensure contiguity
106-
image = image.permute(0, 2, 3, 1).contiguous()
105+
channel_first = image.shape[-1] != 3
106+
if channel_first:
107+
image = image.permute(0, 2, 3, 1).contiguous()
108+
# image = cvcuda.as_tensor(image, nvcv.TensorLayout.NCHW)
109+
# image = cvcuda.reformat(image, nvcv.TensorLayout.NHWC)
107110
image = cvcuda.as_tensor(image, nvcv.TensorLayout.NHWC)
108111
img: cvcuda.Tensor = cvcuda.resize(
109112
image,
@@ -119,7 +122,10 @@ def cudacv_pipeline(image: torch.Tensor, target_size: int) -> torch.Tensor:
119122
img: cvcuda.Tensor = cvcuda.normalize(
120123
img, NORM_MEAN_CUDA_CV, NORM_STD_CUDA_CV
121124
)
122-
return torch.as_tensor(img.cuda())
125+
out = torch.as_tensor(img.cuda())
126+
if channel_first:
127+
out = out.permute(0, 3, 1, 2)
128+
return out
123129

124130

125131
def albumentations_pipeline(image: np.ndarray, target_size: int) -> torch.Tensor:

0 commit comments

Comments
 (0)