Skip to content

Commit 1f24d8f

Browse files
committed
Fix, better tests, expose more stuff
1 parent e100702 commit 1f24d8f

File tree

3 files changed

+42
-45
lines changed

3 files changed

+42
-45
lines changed

test/test_transforms_v2.py

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6171,50 +6171,48 @@ def test_transform_invalid_quality_error(self, quality):
61716171
transforms.JPEG(quality=quality)
61726172

61736173

6174-
class TestQuerySize:
6174+
class TestUtils:
6175+
# TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes
61756176
@pytest.mark.parametrize(
6176-
"make_input, input_name",
6177-
[
6178-
(lambda: torch.rand(3, 32, 64), "pure_tensor"),
6179-
(lambda: tv_tensors.Image(torch.rand(3, 32, 64)), "tv_tensor_image"),
6180-
(lambda: PIL.Image.new("RGB", (64, 32)), "pil_image"),
6181-
(lambda: tv_tensors.Video(torch.rand(1, 3, 32, 64)), "tv_tensor_video"),
6182-
(lambda: tv_tensors.Mask(torch.randint(0, 2, (32, 64))), "tv_tensor_mask"),
6183-
],
6184-
ids=["pure_tensor", "tv_tensor_image", "pil_image", "tv_tensor_video", "tv_tensor_mask"],
6177+
"make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6178+
)
6179+
@pytest.mark.parametrize(
6180+
"make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
61856181
)
6186-
def test_functional(self, make_input, input_name):
6187-
input1 = make_input()
6188-
input2 = make_input()
6189-
# Both inputs should have the same size (32, 64)
6190-
assert transforms.query_size([input1, input2]) == (32, 64)
6182+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6183+
def test_query_size_and_query_chw(self, make_input1, make_input2, query):
6184+
size = (32, 64)
6185+
input1 = make_input1(size)
6186+
input2 = make_input2(size)
6187+
6188+
if query is transforms.query_chw and not any(
6189+
transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
6190+
for inpt in (input1, input2)
6191+
):
6192+
return
6193+
6194+
expected = size if query is transforms.query_size else ((3,) + size)
6195+
assert query([input1, input2]) == expected
61916196

61926197
@pytest.mark.parametrize(
6193-
"make_input, input_name",
6194-
[
6195-
(lambda: torch.rand(3, 32, 64), "pure_tensor"),
6196-
(lambda: tv_tensors.Image(torch.rand(3, 32, 64)), "tv_tensor_image"),
6197-
(lambda: PIL.Image.new("RGB", (64, 32)), "pil_image"),
6198-
(lambda: tv_tensors.Video(torch.rand(1, 3, 32, 64)), "tv_tensor_video"),
6199-
(lambda: tv_tensors.Mask(torch.randint(0, 2, (32, 64))), "tv_tensor_mask"),
6200-
],
6201-
ids=["pure_tensor", "tv_tensor_image", "pil_image", "tv_tensor_video", "tv_tensor_mask"],
6202-
)
6203-
def test_functional_mixed_types(self, make_input, input_name):
6204-
input1 = make_input()
6205-
input2 = make_input()
6206-
# Both inputs should have the same size (32, 64)
6207-
assert transforms.query_size([input1, input2]) == (32, 64)
6208-
6209-
def test_different_sizes(self):
6210-
img_tensor = torch.rand(3, 32, 64) # (C, H, W)
6211-
img_tensor_different_size = torch.rand(3, 48, 96) # (C, H, W)
6212-
# Should raise ValueError for different sizes
6213-
with pytest.raises(ValueError, match="Found multiple HxW dimensions"):
6214-
transforms.query_size([img_tensor, img_tensor_different_size])
6215-
6216-
def test_no_valid_image(self):
6217-
invalid_input = torch.rand(1, 10) # Non-image tensor
6218-
# Should raise TypeError for invalid input
6219-
with pytest.raises(TypeError, match="No image, video, mask or bounding box was found"):
6220-
transforms.query_size([invalid_input])
6198+
"make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6199+
)
6200+
@pytest.mark.parametrize(
6201+
"make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
6202+
)
6203+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6204+
def test_different_sizes(self, make_input1, make_input2, query):
6205+
input1 = make_input1((10, 10))
6206+
input2 = make_input2((20, 20))
6207+
if query is transforms.query_chw and not all(
6208+
transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
6209+
for inpt in (input1, input2)
6210+
):
6211+
return
6212+
with pytest.raises(ValueError, match="Found multiple"):
6213+
query([input1, input2])
6214+
6215+
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
6216+
def test_no_valid_input(self, query):
6217+
with pytest.raises(TypeError, match="No image"):
6218+
query(["blah"])

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,6 @@
5555
)
5656
from ._temporal import UniformTemporalSubsample
5757
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
58+
from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
5859

5960
from ._deprecated import ToTensor # usort: skip

torchvision/transforms/v2/functional/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,3 @@
151151
from ._type_conversion import pil_to_tensor, to_image, to_pil_image
152152

153153
from ._deprecated import get_image_size, to_tensor # usort: skip
154-
155-
from ._utils import query_size

0 commit comments

Comments (0)