@@ -560,6 +560,71 @@ def affine_bounding_boxes(bounding_boxes):
560560 )
561561
562562
def reference_affine_rotated_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
    """Reference implementation that applies an affine matrix to rotated bounding boxes.

    Every box is handled on its own: it is converted to ``XYXYXYXY``, its four corners are
    multiplied by ``affine_matrix`` in homogeneous coordinates, and the result is converted
    back to the input format.

    Args:
        bounding_boxes: Boxes to transform. Its ``format``, ``canvas_size``, ``shape``,
            ``dtype`` and ``device`` are read.
        affine_matrix: NumPy array applied as ``points @ affine_matrix.T`` where ``points``
            has three columns ``(x, y, 1)``; only the first two output columns are used.
        new_canvas_size: Canvas size of the result. Falls back to
            ``bounding_boxes.canvas_size`` when falsy (e.g. ``None``).
        clamp: If true, each box is clamped to the canvas and cast back to the input dtype.
            If false, the box keeps the floating-point dtype produced by the computation.

    Returns:
        A ``tv_tensors.BoundingBoxes`` with the shape and format of the input and the
        resolved canvas size.
    """
    box_format = bounding_boxes.format
    xyxyxyxy = tv_tensors.BoundingBoxFormat.XYXYXYXY
    canvas_size = new_canvas_size or bounding_boxes.canvas_size

    def transform_single_box(box):
        # Convert in float64 first: CXCYWHR -> XYXYXYXY would otherwise lose precision
        # when W or H is 1.
        as_corners = F.convert_bounding_box_format(
            box.to(dtype=torch.float64, device="cpu", copy=True),
            old_format=box_format,
            new_format=xyxyxyxy,
            inplace=True,
        )
        xa, ya, xb, yb, xc, yc, xd, yd = as_corners.squeeze(0).tolist()

        # One homogeneous row (x, y, 1) per corner, in input order.
        homogeneous = np.array(
            [
                [xa, ya, 1.0],
                [xb, yb, 1.0],
                [xc, yc, 1.0],
                [xd, yd, 1.0],
            ]
        )
        moved = homogeneous @ affine_matrix.astype(homogeneous.dtype).T

        # Corners are written back in the order 0, 3, 2, 1, i.e. the second and fourth
        # corners trade places (presumably to keep the corner ordering consistent after a
        # mirroring transform -- TODO confirm).
        # NOTE(review): torch.Tensor(...) builds a tensor of the default dtype (float32
        # unless changed), not float64 -- confirm this is intended.
        result = torch.Tensor([float(moved[corner, axis]) for corner in (0, 3, 2, 1) for axis in (0, 1)])

        result = F.convert_bounding_box_format(result, old_format=xyxyxyxy, new_format=box_format)

        if clamp:
            # Clamping has to happen before the cast, especially for CXCYWH format, dtype=int64.
            result = F.clamp_bounding_boxes(
                result,
                format=box_format,
                canvas_size=canvas_size,
            )
            target_dtype = box.dtype
        else:
            # No cast back to the input dtype: the caller keeps the floating-point result
            # for any additional operation.
            target_dtype = result.dtype

        return result.to(dtype=target_dtype, device=box.device)

    # XYXYXYXY stores 8 values per box; every other format handled here stores 5.
    values_per_box = 8 if box_format == xyxyxyxy else 5
    single_boxes = bounding_boxes.reshape(-1, values_per_box).unbind()
    transformed = torch.cat([transform_single_box(single) for single in single_boxes], dim=0)

    return tv_tensors.BoundingBoxes(
        transformed.reshape(bounding_boxes.shape),
        format=box_format,
        canvas_size=canvas_size,
    )
627+
563628class TestResize :
564629 INPUT_SIZE = (17 , 11 )
565630 OUTPUT_SIZES = [17 , [17 ], (17 ,), None , [12 , 13 ], (12 , 13 )]
@@ -1012,7 +1077,7 @@ class TestHorizontalFlip:
def test_kernel_image(self, dtype, device):
    """Hand ``F.horizontal_flip_image`` and a freshly made image to ``check_kernel``."""
    image = make_image(dtype=dtype, device=device)
    check_kernel(F.horizontal_flip_image, image)
10141079
1015- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
1080+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
10161081 @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
10171082 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
10181083 def test_kernel_bounding_boxes (self , format , dtype , device ):
@@ -1071,25 +1136,27 @@ def test_image_correctness(self, fn):
10711136
10721137 torch .testing .assert_close (actual , expected )
10731138
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes, format=None):
    """Compute the expected result of horizontally flipping ``bounding_boxes``.

    Args:
        bounding_boxes: Boxes to flip; ``canvas_size`` (and ``format`` when the argument
            below is omitted) are read from it.
        format: Box format used to choose between the rotated and the non-rotated
            reference helper. Defaults to ``bounding_boxes.format`` so the two cannot
            drift apart; an explicit value is still honoured for existing callers.

    Returns:
        Whatever the selected reference helper returns for the flip matrix.
    """
    if format is None:
        format = bounding_boxes.format

    # Maps (x, y) to (canvas_size[1] - x, y); canvas_size[1] is presumably the canvas
    # width -- verify against the BoundingBoxes definition.
    affine_matrix = np.array(
        [
            [-1, 0, bounding_boxes.canvas_size[1]],
            [0, 1, 0],
        ],
    )

    if tv_tensors.is_rotated_bounding_format(format):
        return reference_affine_rotated_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
    return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
10831150
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize(
    "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
)
def test_bounding_boxes_correctness(self, format, fn):
    """Compare ``fn`` on freshly made boxes against the horizontal-flip reference.

    Runs for every member of ``tv_tensors.BoundingBoxFormat`` and for both the functional
    and the transform-class entry point.
    """
    boxes = make_bounding_boxes(format=format)

    flipped = fn(boxes)
    reference = self._reference_horizontal_flip_bounding_boxes(boxes, format)

    torch.testing.assert_close(flipped, reference)
10951162
0 commit comments