@@ -2944,7 +2944,7 @@ def test_kernel_image(self, kwargs, dtype, device):
29442944 check_kernel (F .crop_image , make_image (self .INPUT_SIZE , dtype = dtype , device = device ), ** kwargs )
29452945
29462946 @pytest .mark .parametrize ("kwargs" , CORRECTNESS_CROP_KWARGS )
2947- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
2947+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
29482948 @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
29492949 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
29502950 def test_kernel_bounding_box (self , kwargs , format , dtype , device ):
@@ -3089,12 +3089,15 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w
30893089 [0 , 1 , - top ],
30903090 ],
30913091 )
3092- return reference_affine_bounding_boxes_helper (
3093- bounding_boxes , affine_matrix = affine_matrix , new_canvas_size = (height , width )
3092+ helper = (
3093+ reference_affine_rotated_bounding_boxes_helper
3094+ if tv_tensors .is_rotated_bounding_format (bounding_boxes .format )
3095+ else reference_affine_bounding_boxes_helper
30943096 )
3097+ return helper (bounding_boxes , affine_matrix = affine_matrix , new_canvas_size = (height , width ))
30953098
30963099 @pytest .mark .parametrize ("kwargs" , CORRECTNESS_CROP_KWARGS )
3097- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
3100+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
30983101 @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
30993102 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
31003103 def test_functional_bounding_box_correctness (self , kwargs , format , dtype , device ):
@@ -3107,7 +3110,7 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device
31073110 assert_equal (F .get_size (actual ), F .get_size (expected ))
31083111
31093112 @pytest .mark .parametrize ("output_size" , [(17 , 11 ), (11 , 17 ), (11 , 11 )])
3110- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
3113+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
31113114 @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
31123115 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
31133116 @pytest .mark .parametrize ("seed" , list (range (5 )))
@@ -3129,7 +3132,7 @@ def test_transform_bounding_boxes_correctness(self, output_size, format, dtype,
31293132
31303133 expected = self ._reference_crop_bounding_boxes (bounding_boxes , ** params )
31313134
3132- assert_equal (actual , expected )
3135+ assert_equal (actual , expected , atol = 1 , rtol = 0 )
31333136 assert_equal (F .get_size (actual ), F .get_size (expected ))
31343137
31353138 def test_errors (self ):
@@ -3864,13 +3867,19 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
38643867 )
38653868 affine_matrix = (resize_affine_matrix @ crop_affine_matrix )[:2 , :]
38663869
3867- return reference_affine_bounding_boxes_helper (
3870+ helper = (
3871+ reference_affine_rotated_bounding_boxes_helper
3872+ if tv_tensors .is_rotated_bounding_format (bounding_boxes .format )
3873+ else reference_affine_bounding_boxes_helper
3874+ )
3875+
3876+ return helper (
38683877 bounding_boxes ,
38693878 affine_matrix = affine_matrix ,
38703879 new_canvas_size = size ,
38713880 )
38723881
3873- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
3882+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
38743883 def test_functional_bounding_boxes_correctness (self , format ):
38753884 bounding_boxes = make_bounding_boxes (self .INPUT_SIZE , format = format )
38763885
@@ -3879,7 +3888,7 @@ def test_functional_bounding_boxes_correctness(self, format):
38793888 bounding_boxes , ** self .CROP_KWARGS , size = self .OUTPUT_SIZE
38803889 )
38813890
3882- assert_equal (actual , expected )
3891+ torch . testing . assert_close (actual , expected )
38833892 assert_equal (F .get_size (actual ), F .get_size (expected ))
38843893
38853894 def test_transform_errors_warnings (self ):
@@ -4101,7 +4110,7 @@ def test_kernel_image(self, output_size, dtype, device):
41014110 )
41024111
41034112 @pytest .mark .parametrize ("output_size" , OUTPUT_SIZES )
4104- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
4113+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
41054114 def test_kernel_bounding_boxes (self , output_size , format ):
41064115 bounding_boxes = make_bounding_boxes (self .INPUT_SIZE , format = format )
41074116 check_kernel (
@@ -4175,12 +4184,15 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
41754184 [0 , 1 , - top ],
41764185 ],
41774186 )
4178- return reference_affine_bounding_boxes_helper (
4179- bounding_boxes , affine_matrix = affine_matrix , new_canvas_size = output_size
4187+ helper = (
4188+ reference_affine_rotated_bounding_boxes_helper
4189+ if tv_tensors .is_rotated_bounding_format (bounding_boxes .format )
4190+ else reference_affine_bounding_boxes_helper
41804191 )
4192+ return helper (bounding_boxes , affine_matrix = affine_matrix , new_canvas_size = output_size )
41814193
41824194 @pytest .mark .parametrize ("output_size" , OUTPUT_SIZES )
4183- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
4195+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
41844196 @pytest .mark .parametrize ("dtype" , [torch .int64 , torch .float32 ])
41854197 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
41864198 @pytest .mark .parametrize ("fn" , [F .center_crop , transform_cls_to_functional (transforms .CenterCrop )])
@@ -4190,7 +4202,7 @@ def test_bounding_boxes_correctness(self, output_size, format, dtype, device, fn
41904202 actual = fn (bounding_boxes , output_size )
41914203 expected = self ._reference_center_crop_bounding_boxes (bounding_boxes , output_size )
41924204
4193- assert_equal (actual , expected )
4205+ torch . testing . assert_close (actual , expected )
41944206
41954207
41964208class TestPerspective :
0 commit comments