@@ -6171,50 +6171,48 @@ def test_transform_invalid_quality_error(self, quality):
61716171 transforms .JPEG (quality = quality )
61726172
61736173
6174- class TestQuerySize :
6174+ class TestUtils :
6175+ # TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes
61756176 @pytest .mark .parametrize (
6176- "make_input, input_name" ,
6177- [
6178- (lambda : torch .rand (3 , 32 , 64 ), "pure_tensor" ),
6179- (lambda : tv_tensors .Image (torch .rand (3 , 32 , 64 )), "tv_tensor_image" ),
6180- (lambda : PIL .Image .new ("RGB" , (64 , 32 )), "pil_image" ),
6181- (lambda : tv_tensors .Video (torch .rand (1 , 3 , 32 , 64 )), "tv_tensor_video" ),
6182- (lambda : tv_tensors .Mask (torch .randint (0 , 2 , (32 , 64 ))), "tv_tensor_mask" ),
6183- ],
6184- ids = ["pure_tensor" , "tv_tensor_image" , "pil_image" , "tv_tensor_video" , "tv_tensor_mask" ],
6177+ "make_input1" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6178+ )
6179+ @pytest .mark .parametrize (
6180+ "make_input2" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
61856181 )
6186- def test_functional (self , make_input , input_name ):
6187- input1 = make_input ()
6188- input2 = make_input ()
6189- # Both inputs should have the same size (32, 64)
6190- assert transforms .query_size ([input1 , input2 ]) == (32 , 64 )
6182+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6183+ def test_query_size_and_query_chw (self , make_input1 , make_input2 , query ):
6184+ size = (32 , 64 )
6185+ input1 = make_input1 (size )
6186+ input2 = make_input2 (size )
6187+
6188+ if query is transforms .query_chw and not any (
6189+ transforms .check_type (inpt , (is_pure_tensor , tv_tensors .Image , PIL .Image .Image , tv_tensors .Video ))
6190+ for inpt in (input1 , input2 )
6191+ ):
6192+ return
6193+
6194+ expected = size if query is transforms .query_size else ((3 ,) + size )
6195+ assert query ([input1 , input2 ]) == expected
61916196
61926197 @pytest .mark .parametrize (
6193- "make_input, input_name" ,
6194- [
6195- (lambda : torch .rand (3 , 32 , 64 ), "pure_tensor" ),
6196- (lambda : tv_tensors .Image (torch .rand (3 , 32 , 64 )), "tv_tensor_image" ),
6197- (lambda : PIL .Image .new ("RGB" , (64 , 32 )), "pil_image" ),
6198- (lambda : tv_tensors .Video (torch .rand (1 , 3 , 32 , 64 )), "tv_tensor_video" ),
6199- (lambda : tv_tensors .Mask (torch .randint (0 , 2 , (32 , 64 ))), "tv_tensor_mask" ),
6200- ],
6201- ids = ["pure_tensor" , "tv_tensor_image" , "pil_image" , "tv_tensor_video" , "tv_tensor_mask" ],
6202- )
6203- def test_functional_mixed_types (self , make_input , input_name ):
6204- input1 = make_input ()
6205- input2 = make_input ()
6206- # Both inputs should have the same size (32, 64)
6207- assert transforms .query_size ([input1 , input2 ]) == (32 , 64 )
6208-
6209- def test_different_sizes (self ):
6210- img_tensor = torch .rand (3 , 32 , 64 ) # (C, H, W)
6211- img_tensor_different_size = torch .rand (3 , 48 , 96 ) # (C, H, W)
6212- # Should raise ValueError for different sizes
6213- with pytest .raises (ValueError , match = "Found multiple HxW dimensions" ):
6214- transforms .query_size ([img_tensor , img_tensor_different_size ])
6215-
6216- def test_no_valid_image (self ):
6217- invalid_input = torch .rand (1 , 10 ) # Non-image tensor
6218- # Should raise TypeError for invalid input
6219- with pytest .raises (TypeError , match = "No image, video, mask or bounding box was found" ):
6220- transforms .query_size ([invalid_input ])
6198+ "make_input1" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6199+ )
6200+ @pytest .mark .parametrize (
6201+ "make_input2" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6202+ )
6203+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6204+ def test_different_sizes (self , make_input1 , make_input2 , query ):
6205+ input1 = make_input1 ((10 , 10 ))
6206+ input2 = make_input2 ((20 , 20 ))
6207+ if query is transforms .query_chw and not all (
6208+ transforms .check_type (inpt , (is_pure_tensor , tv_tensors .Image , PIL .Image .Image , tv_tensors .Video ))
6209+ for inpt in (input1 , input2 )
6210+ ):
6211+ return
6212+ with pytest .raises (ValueError , match = "Found multiple" ):
6213+ query ([input1 , input2 ])
6214+
6215+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6216+ def test_no_valid_input (self , query ):
6217+ with pytest .raises (TypeError , match = "No image" ):
6218+ query (["blah" ])
0 commit comments