@@ -6169,3 +6169,52 @@ def test_transform_sequence_len_error(self, quality):
61696169 def test_transform_invalid_quality_error (self , quality ):
61706170 with pytest .raises (ValueError , match = "quality must be an integer from 1 to 100" ):
61716171 transforms .JPEG (quality = quality )
6172+
6173+
6174+ class TestQuerySize :
6175+ @pytest .mark .parametrize (
6176+ "make_input, input_name" ,
6177+ [
6178+ (lambda : torch .rand (3 , 32 , 64 ), "pure_tensor" ),
6179+ (lambda : tv_tensors .Image (torch .rand (3 , 32 , 64 )), "tv_tensor_image" ),
6180+ (lambda : PIL .Image .new ("RGB" , (64 , 32 )), "pil_image" ),
6181+ (lambda : tv_tensors .Video (torch .rand (1 , 3 , 32 , 64 )), "tv_tensor_video" ),
6182+ (lambda : tv_tensors .Mask (torch .randint (0 , 2 , (32 , 64 ))), "tv_tensor_mask" ),
6183+ ],
6184+ ids = ["pure_tensor" , "tv_tensor_image" , "pil_image" , "tv_tensor_video" , "tv_tensor_mask" ],
6185+ )
6186+ def test_functional (self , make_input , input_name ):
6187+ input1 = make_input ()
6188+ input2 = make_input ()
6189+ # Both inputs should have the same size (32, 64)
6190+ assert transforms .query_size ([input1 , input2 ]) == (32 , 64 )
6191+
6192+ @pytest .mark .parametrize (
6193+ "make_input, input_name" ,
6194+ [
6195+ (lambda : torch .rand (3 , 32 , 64 ), "pure_tensor" ),
6196+ (lambda : tv_tensors .Image (torch .rand (3 , 32 , 64 )), "tv_tensor_image" ),
6197+ (lambda : PIL .Image .new ("RGB" , (64 , 32 )), "pil_image" ),
6198+ (lambda : tv_tensors .Video (torch .rand (1 , 3 , 32 , 64 )), "tv_tensor_video" ),
6199+ (lambda : tv_tensors .Mask (torch .randint (0 , 2 , (32 , 64 ))), "tv_tensor_mask" ),
6200+ ],
6201+ ids = ["pure_tensor" , "tv_tensor_image" , "pil_image" , "tv_tensor_video" , "tv_tensor_mask" ],
6202+ )
6203+ def test_functional_mixed_types (self , make_input , input_name ):
6204+ input1 = make_input ()
6205+ input2 = make_input ()
6206+ # Both inputs should have the same size (32, 64)
6207+ assert transforms .query_size ([input1 , input2 ]) == (32 , 64 )
6208+
6209+ def test_different_sizes (self ):
6210+ img_tensor = torch .rand (3 , 32 , 64 ) # (C, H, W)
6211+ img_tensor_different_size = torch .rand (3 , 48 , 96 ) # (C, H, W)
6212+ # Should raise ValueError for different sizes
6213+ with pytest .raises (ValueError , match = "Found multiple HxW dimensions" ):
6214+ transforms .query_size ([img_tensor , img_tensor_different_size ])
6215+
6216+ def test_no_valid_image (self ):
6217+ invalid_input = torch .rand (1 , 10 ) # Non-image tensor
6218+ # Should raise TypeError for invalid input
6219+ with pytest .raises (TypeError , match = "No image, video, mask or bounding box was found" ):
6220+ transforms .query_size ([invalid_input ])
0 commit comments