
Commit e100702

make_public_query_size

1 parent: 9832166

File tree (2 files changed, +51 -0 lines):

    test/test_transforms_v2.py
    torchvision/transforms/v2/functional/__init__.py

test/test_transforms_v2.py

Lines changed: 49 additions & 0 deletions
@@ -6169,3 +6169,52 @@ def test_transform_sequence_len_error(self, quality):
     def test_transform_invalid_quality_error(self, quality):
         with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
             transforms.JPEG(quality=quality)
+
+
+class TestQuerySize:
+    @pytest.mark.parametrize(
+        "make_input, input_name",
+        [
+            (lambda: torch.rand(3, 32, 64), "pure_tensor"),
+            (lambda: tv_tensors.Image(torch.rand(3, 32, 64)), "tv_tensor_image"),
+            (lambda: PIL.Image.new("RGB", (64, 32)), "pil_image"),
+            (lambda: tv_tensors.Video(torch.rand(1, 3, 32, 64)), "tv_tensor_video"),
+            (lambda: tv_tensors.Mask(torch.randint(0, 2, (32, 64))), "tv_tensor_mask"),
+        ],
+        ids=["pure_tensor", "tv_tensor_image", "pil_image", "tv_tensor_video", "tv_tensor_mask"],
+    )
+    def test_functional(self, make_input, input_name):
+        input1 = make_input()
+        input2 = make_input()
+        # Both inputs should have the same size (32, 64)
+        assert transforms.query_size([input1, input2]) == (32, 64)
+
+    @pytest.mark.parametrize(
+        "make_input, input_name",
+        [
+            (lambda: torch.rand(3, 32, 64), "pure_tensor"),
+            (lambda: tv_tensors.Image(torch.rand(3, 32, 64)), "tv_tensor_image"),
+            (lambda: PIL.Image.new("RGB", (64, 32)), "pil_image"),
+            (lambda: tv_tensors.Video(torch.rand(1, 3, 32, 64)), "tv_tensor_video"),
+            (lambda: tv_tensors.Mask(torch.randint(0, 2, (32, 64))), "tv_tensor_mask"),
+        ],
+        ids=["pure_tensor", "tv_tensor_image", "pil_image", "tv_tensor_video", "tv_tensor_mask"],
+    )
+    def test_functional_mixed_types(self, make_input, input_name):
+        input1 = make_input()
+        input2 = make_input()
+        # Both inputs should have the same size (32, 64)
+        assert transforms.query_size([input1, input2]) == (32, 64)
+
+    def test_different_sizes(self):
+        img_tensor = torch.rand(3, 32, 64)  # (C, H, W)
+        img_tensor_different_size = torch.rand(3, 48, 96)  # (C, H, W)
+        # Should raise ValueError for different sizes
+        with pytest.raises(ValueError, match="Found multiple HxW dimensions"):
+            transforms.query_size([img_tensor, img_tensor_different_size])
+
+    def test_no_valid_image(self):
+        invalid_input = torch.rand(1, 10)  # Non-image tensor
+        # Should raise TypeError for invalid input
+        with pytest.raises(TypeError, match="No image, video, mask or bounding box was found"):
+            transforms.query_size([invalid_input])
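
For readers unfamiliar with the helper: query_size is a developer utility for writing custom v2 transforms; it scans a flat list of inputs and returns their common (H, W), as the tests above pin down. Below is a minimal, hypothetical sketch of that use. It assumes the public Transform.make_params/transform hooks of recent torchvision (older releases use the underscored _get_params/_transform variants), and the TopLeftHalfCrop class and its parameters are invented for illustration.

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F


# Hypothetical transform (not part of this commit): query_size reports the
# common (H, W) of all image-like inputs, so crop parameters can be derived
# once per call and applied consistently to every input.
class TopLeftHalfCrop(transforms.Transform):
    def make_params(self, flat_inputs):
        h, w = transforms.query_size(flat_inputs)
        return dict(top=0, left=0, height=h // 2, width=w // 2)

    def transform(self, inpt, params):
        return F.crop(inpt, **params)


img = tv_tensors.Image(torch.rand(3, 32, 64))
out = TopLeftHalfCrop()(img)
print(tuple(out.shape))  # (3, 16, 32)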

torchvision/transforms/v2/functional/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -151,3 +151,5 @@
 from ._type_conversion import pil_to_tensor, to_image, to_pil_image
 
 from ._deprecated import get_image_size, to_tensor  # usort: skip
+
+from ._utils import query_size
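
With this export in place, the helper should also be importable from the functional namespace. A quick sanity check, mirroring the behavior exercised by the tests above:

import torch
from torchvision.transforms.v2 import functional as F

t = torch.rand(3, 32, 64)  # (C, H, W)
print(F.query_size([t]))   # (32, 64)

# Per the tests above, mismatched sizes raise ValueError:
#   F.query_size([t, torch.rand(3, 48, 96)])
# and non-image inputs raise TypeError:
#   F.query_size([torch.rand(1, 10)])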
