@@ -6169,3 +6169,50 @@ def test_transform_sequence_len_error(self, quality):
61696169 def test_transform_invalid_quality_error (self , quality ):
61706170 with pytest .raises (ValueError , match = "quality must be an integer from 1 to 100" ):
61716171 transforms .JPEG (quality = quality )
6172+
6173+
6174+ class TestUtils :
6175+ # TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes
6176+ @pytest .mark .parametrize (
6177+ "make_input1" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6178+ )
6179+ @pytest .mark .parametrize (
6180+ "make_input2" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6181+ )
6182+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6183+ def test_query_size_and_query_chw (self , make_input1 , make_input2 , query ):
6184+ size = (32 , 64 )
6185+ input1 = make_input1 (size )
6186+ input2 = make_input2 (size )
6187+
6188+ if query is transforms .query_chw and not any (
6189+ transforms .check_type (inpt , (is_pure_tensor , tv_tensors .Image , PIL .Image .Image , tv_tensors .Video ))
6190+ for inpt in (input1 , input2 )
6191+ ):
6192+ return
6193+
6194+ expected = size if query is transforms .query_size else ((3 ,) + size )
6195+ assert query ([input1 , input2 ]) == expected
6196+
6197+ @pytest .mark .parametrize (
6198+ "make_input1" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6199+ )
6200+ @pytest .mark .parametrize (
6201+ "make_input2" , [make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask ]
6202+ )
6203+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6204+ def test_different_sizes (self , make_input1 , make_input2 , query ):
6205+ input1 = make_input1 ((10 , 10 ))
6206+ input2 = make_input2 ((20 , 20 ))
6207+ if query is transforms .query_chw and not all (
6208+ transforms .check_type (inpt , (is_pure_tensor , tv_tensors .Image , PIL .Image .Image , tv_tensors .Video ))
6209+ for inpt in (input1 , input2 )
6210+ ):
6211+ return
6212+ with pytest .raises (ValueError , match = "Found multiple" ):
6213+ query ([input1 , input2 ])
6214+
6215+ @pytest .mark .parametrize ("query" , [transforms .query_size , transforms .query_chw ])
6216+ def test_no_valid_input (self , query ):
6217+ with pytest .raises (TypeError , match = "No image" ):
6218+ query (["blah" ])
0 commit comments