Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
44db71c
implement additional cvcuda infra for all branches to avoid duplicate…
justincdavis Nov 25, 2025
e3dd700
update make_image_cvcuda to have default batch dim
justincdavis Nov 25, 2025
c035df1
add standardized setup to main for easier updating of PRs and branches
justincdavis Dec 2, 2025
98d7dfb
update is_cvcuda_tensor
justincdavis Dec 2, 2025
ddc116d
add cvcuda to pil compatible to transforms by default
justincdavis Dec 2, 2025
e51dc7e
remove cvcuda from transform class
justincdavis Dec 2, 2025
e14e210
merge with main
justincdavis Dec 4, 2025
4939355
resolve more formatting naming
justincdavis Dec 4, 2025
fbea584
update is cvcuda tensor impl
justincdavis Dec 4, 2025
ffe7a14
initial cvcuda crop implementation, only minimal tests so far
justincdavis Nov 18, 2025
9133c3d
add padding to centercrop and if needed to crop
justincdavis Nov 18, 2025
878d2ae
test padding for crop_cvcuda, add functional test
justincdavis Nov 18, 2025
2219ee5
center_crop passes functional equiv
justincdavis Nov 18, 2025
ed2bd35
fix: crop testing, adhere to conventions
justincdavis Nov 20, 2025
3582c58
Fix: update center crop
justincdavis Nov 20, 2025
18922e3
handle some comments from other prs review
justincdavis Nov 24, 2025
37a91e0
simplify and improve crop testing for cvcuda
justincdavis Nov 26, 2025
9b721ef
simplify test for center crop cvcuda
justincdavis Nov 26, 2025
6a0035d
begin work on finalizing the crop PR to include five and ten crop, ad…
justincdavis Dec 1, 2025
e287fc1
update to include five ten crop and resized crop, use placeholder tra…
justincdavis Dec 2, 2025
540551a
update crop to new main standards
justincdavis Dec 4, 2025
62877ca
reduce diff
justincdavis Dec 4, 2025
c2964f8
check input type on kernel for signature test
justincdavis Dec 4, 2025
68c1827
start minimizing diff, add dummy for pad so all crop tests are assert…
justincdavis Dec 12, 2025
fff50a5
more diff minimization
justincdavis Dec 12, 2025
adb11ea
more minimization from standard infra
justincdavis Dec 12, 2025
9ceb844
drop dummy funcs
justincdavis Dec 12, 2025
0eb8b6f
add cvcuda tensor to query_size
justincdavis Dec 12, 2025
47943c0
Merge remote-tracking branch 'upstream/main' into feat/crop_cvcuda
justincdavis Dec 18, 2025
c794596
update marks
justincdavis Dec 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 135 additions & 16 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -42,7 +43,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -3497,6 +3497,7 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_functional(self, make_input):
Expand All @@ -3512,16 +3513,34 @@ def test_functional(self, make_input):
(F.crop_mask, tv_tensors.Mask),
(F.crop_video, tv_tensors.Video),
(F.crop_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._crop_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, kernel, input_type):
    """Run ``check_functional_kernel_signature_match`` for ``F.crop`` against ``kernel``.

    The CV-CUDA kernel is parametrized with ``input_type=None``; for that
    kernel the type is resolved here through ``_import_cvcuda().Tensor``
    (presumably so cvcuda is only imported when the test actually runs).
    """
    if kernel is F._geometry._crop_image_cvcuda:
        resolved_type = _import_cvcuda().Tensor
    else:
        resolved_type = input_type
    check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=resolved_type)

@pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
def test_functional_image_correctness(self, kwargs):
image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_functional_image_correctness(self, kwargs, make_input):
image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

actual = F.crop(image, **kwargs)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs))

assert_equal(actual, expected)
Expand All @@ -3540,6 +3559,7 @@ def test_functional_image_correctness(self, kwargs):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_transform(self, param, value, make_input):
Expand Down Expand Up @@ -3606,7 +3626,14 @@ def test_transform_pad_if_needed(self):
padding_mode=["constant", "edge", "reflect", "symmetric"],
)
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_image_correctness(self, param, value, seed):
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_transform_image_correctness(self, param, value, seed, make_input):
kwargs = {param: value}
if param != "size":
# 1. size is required
Expand All @@ -3617,13 +3644,17 @@ def test_transform_image_correctness(self, param, value, seed):

transform = transforms.RandomCrop(pad_if_needed=True, **kwargs)

image = make_image(self.INPUT_SIZE)
image = make_input(self.INPUT_SIZE)

with freeze_rng_state():
torch.manual_seed(seed)
actual = transform(image)

torch.manual_seed(seed)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(transform(F.to_pil_image(image)))

assert_equal(actual, expected)
Expand Down Expand Up @@ -4450,6 +4481,7 @@ def test_kernel(self, kernel, make_input):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_functional(self, make_input):
Expand All @@ -4466,9 +4498,16 @@ def test_functional(self, make_input):
(F.resized_crop_mask, tv_tensors.Mask),
(F.resized_crop_video, tv_tensors.Video),
(F.resized_crop_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._resized_crop_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, kernel, input_type):
    """Run ``check_functional_kernel_signature_match`` for ``F.resized_crop`` against ``kernel``.

    The CV-CUDA kernel arrives with ``input_type=None`` from the
    parametrization, so its type is substituted with
    ``_import_cvcuda().Tensor`` before the check.
    """
    uses_cvcuda_kernel = kernel is F._geometry._resized_crop_image_cvcuda
    resolved_type = _import_cvcuda().Tensor if uses_cvcuda_kernel else input_type
    check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=resolved_type)

@param_value_parametrization(
Expand All @@ -4485,6 +4524,7 @@ def test_functional_signature(self, kernel, input_type):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_transform(self, param, value, make_input):
Expand All @@ -4496,20 +4536,31 @@ def test_transform(self, param, value, make_input):

# `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
# The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
def test_functional_image_correctness(self, interpolation):
image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
def test_functional_image_correctness(self, make_input, interpolation):
image = make_input(self.INPUT_SIZE, dtype=torch.uint8)

actual = F.resized_crop(
image, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation, antialias=True
)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(
F.resized_crop(
F.to_pil_image(image), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation
)
)

torch.testing.assert_close(actual, expected, atol=1, rtol=0)
assert_close(actual, expected, atol=1, rtol=0)

def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size):
new_height, new_width = size
Expand Down Expand Up @@ -4920,6 +4971,7 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_functional(self, make_input):
Expand All @@ -4935,9 +4987,16 @@ def test_functional(self, make_input):
(F.center_crop_mask, tv_tensors.Mask),
(F.center_crop_video, tv_tensors.Video),
(F.center_crop_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._center_crop_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, kernel, input_type):
    """Run ``check_functional_kernel_signature_match`` for ``F.center_crop`` against ``kernel``.

    For the CV-CUDA kernel (parametrized with ``input_type=None``) the
    type is replaced by ``_import_cvcuda().Tensor`` first.
    """
    if kernel is not F._geometry._center_crop_image_cvcuda:
        resolved_type = input_type
    else:
        # The parametrization cannot name the cvcuda type directly, so look it up now.
        resolved_type = _import_cvcuda().Tensor
    check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=resolved_type)

@pytest.mark.parametrize(
Expand All @@ -4950,17 +5009,29 @@ def test_functional_signature(self, kernel, input_type):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
def test_transform(self, make_input):
    """Run ``check_transform`` on a ``CenterCrop`` with one input built by ``make_input``."""
    # Build the transform before the input, matching the original evaluation order.
    center_crop = transforms.CenterCrop(self.OUTPUT_SIZES[0])
    sample = make_input(self.INPUT_SIZE)
    check_transform(center_crop, sample)

@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
def test_image_correctness(self, output_size, fn):
image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
def test_image_correctness(self, output_size, make_input, fn):
image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

actual = fn(image, output_size)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))

assert_equal(actual, expected)
Expand Down Expand Up @@ -6235,7 +6306,13 @@ def wrapper(*args, **kwargs):

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop])
def test_functional(self, make_input, functional):
Expand All @@ -6253,13 +6330,27 @@ def test_functional(self, make_input, functional):
(F.five_crop, F._geometry._five_crop_image_pil, PIL.Image.Image),
(F.five_crop, F.five_crop_image, tv_tensors.Image),
(F.five_crop, F.five_crop_video, tv_tensors.Video),
pytest.param(
F.five_crop,
F._geometry._five_crop_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
(F.ten_crop, F.ten_crop_image, torch.Tensor),
(F.ten_crop, F._geometry._ten_crop_image_pil, PIL.Image.Image),
(F.ten_crop, F.ten_crop_image, tv_tensors.Image),
(F.ten_crop, F.ten_crop_video, tv_tensors.Video),
pytest.param(
F.ten_crop,
F._geometry._ten_crop_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, functional, kernel, input_type):
    """Run ``check_functional_kernel_signature_match`` for ``functional`` against ``kernel``.

    Both CV-CUDA kernels (five-crop and ten-crop) are parametrized with
    ``input_type=None``; for either of them the type is replaced by
    ``_import_cvcuda().Tensor`` before the check.
    """
    # ``or`` keeps the original short-circuit: the ten-crop attribute is only
    # read when the kernel is not the five-crop one.
    uses_cvcuda_kernel = (
        kernel is F._geometry._five_crop_image_cvcuda
        or kernel is F._geometry._ten_crop_image_cvcuda
    )
    resolved_type = _import_cvcuda().Tensor if uses_cvcuda_kernel else input_type
    check_functional_kernel_signature_match(functional, kernel=kernel, input_type=resolved_type)

class _TransformWrapper(nn.Module):
Expand All @@ -6281,7 +6372,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
def test_transform(self, make_input, transform_cls):
Expand All @@ -6299,29 +6396,51 @@ def test_transform_error(self, make_input, transform_cls):
with pytest.raises(TypeError, match="not supported"):
transform(make_input(self.INPUT_SIZE))

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("fn", [F.five_crop, transform_cls_to_functional(transforms.FiveCrop)])
def test_correctness_image_five_crop(self, fn):
image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
def test_correctness_image_five_crop(self, make_input, fn):
image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

actual = fn(image, size=self.OUTPUT_SIZE)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.five_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE)

assert isinstance(actual, tuple)
assert_equal(actual, [F.to_image(e) for e in expected])

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("fn_or_class", [F.ten_crop, transforms.TenCrop])
@pytest.mark.parametrize("vertical_flip", [False, True])
def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
def test_correctness_image_ten_crop(self, make_input, fn_or_class, vertical_flip):
if fn_or_class is transforms.TenCrop:
fn = transform_cls_to_functional(fn_or_class, size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
kwargs = dict()
else:
fn = fn_or_class
kwargs = dict(size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)

image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

actual = fn(image, **kwargs)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.ten_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)

assert isinstance(actual, tuple)
Expand Down
12 changes: 12 additions & 0 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ class CenterCrop(Transform):

_v1_transform_cls = _transforms.CenterCrop

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(self, size: Union[int, Sequence[int]]):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
Expand Down Expand Up @@ -252,6 +254,8 @@ class RandomResizedCrop(Transform):

_v1_transform_cls = _transforms.RandomResizedCrop

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(
self,
size: Union[int, Sequence[int]],
Expand Down Expand Up @@ -360,6 +364,8 @@ class FiveCrop(Transform):

_v1_transform_cls = _transforms.FiveCrop

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
Expand Down Expand Up @@ -404,6 +410,8 @@ class TenCrop(Transform):

_v1_transform_cls = _transforms.TenCrop

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
Expand Down Expand Up @@ -811,6 +819,8 @@ class RandomCrop(Transform):

_v1_transform_cls = _transforms.RandomCrop

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def _extract_params_for_v1_transform(self) -> dict[str, Any]:
params = super()._extract_params_for_v1_transform()

Expand Down Expand Up @@ -1121,6 +1131,8 @@ class RandomIoUCrop(Transform):
Default, 40.
"""

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(
self,
min_scale: float = 0.3,
Expand Down
Loading