@@ -560,6 +560,78 @@ def affine_bounding_boxes(bounding_boxes):
560560 )
561561
562562
def reference_affine_rotated_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
    """Reference implementation that maps rotated bounding boxes through ``affine_matrix``, one box at a time.

    Every box is converted to its four corner points (``XYXYXYXY``), the corners are multiplied by the affine
    matrix with NumPy, and the result is converted back to the format of the input.

    Args:
        bounding_boxes: Boxes to transform. Their ``format``, ``canvas_size``, ``dtype``, ``device`` and ``shape``
            are read; the input is reshaped to rows of 8 values for ``XYXYXYXY`` and rows of 5 values otherwise.
        affine_matrix: NumPy array applied to the homogeneous corner points ``[x, y, 1]``.
        new_canvas_size: Canvas size of the result. Falls back to ``bounding_boxes.canvas_size`` when falsy.
        clamp: If ``True``, the transformed boxes are clamped to the canvas and cast back to the input dtype.
            If ``False``, they are returned unclamped in the floating point dtype of the computation.

    Returns:
        A ``tv_tensors.BoundingBoxes`` with the same shape and format as the input.
    """
    box_format = bounding_boxes.format
    canvas_size = new_canvas_size or bounding_boxes.canvas_size
    corner_format = tv_tensors.BoundingBoxFormat.XYXYXYXY

    # Number of values that describe a single box in the input format.
    if box_format != corner_format:
        values_per_box = 5
    else:
        values_per_box = 8

    def transform_single_box(box):
        target_dtype = box.dtype
        target_device = box.device

        # Convert on a float64 CPU copy: going through float first prevents precision loss in case of
        # CXCYWHR -> XYXYXYXY when W or H is 1.
        float_box = box.to(dtype=torch.float64, device="cpu", copy=True)
        corners = F.convert_bounding_box_format(
            float_box,
            old_format=box_format,
            new_format=corner_format,
            inplace=True,
        )
        flat = corners.squeeze(0).tolist()

        # One homogeneous row [x, y, 1] per corner.
        homogeneous = np.array([[px, py, 1.0] for px, py in zip(flat[0::2], flat[1::2])])
        moved = np.matmul(homogeneous, affine_matrix.astype(homogeneous.dtype).T)

        # The corners are emitted pairwise swapped (2nd, 1st, 4th, 3rd).
        # NOTE(review): presumably this restores the corner ordering convention after a flip - confirm.
        swapped = moved[[1, 0, 3, 2], :2].reshape(-1)
        result = torch.tensor([float(value) for value in swapped])

        result = F.convert_bounding_box_format(result, old_format=corner_format, new_format=box_format)

        if not clamp:
            # Keep the floating point dtype of the computation so the caller gets the full precision to perform
            # any additional operation.
            return result.to(dtype=result.dtype, device=target_device)

        # It is important to clamp before casting, especially for CXCYWHR format, dtype=int64
        result = F.clamp_bounding_boxes(result, format=box_format, canvas_size=canvas_size)
        return result.to(dtype=target_dtype, device=target_device)

    individual_boxes = bounding_boxes.reshape(-1, values_per_box).unbind()
    transformed = torch.cat([transform_single_box(box) for box in individual_boxes], dim=0)

    return tv_tensors.BoundingBoxes(
        transformed.reshape(bounding_boxes.shape),
        format=box_format,
        canvas_size=canvas_size,
    )
633+
634+
563635class TestResize :
564636 INPUT_SIZE = (17 , 11 )
565637 OUTPUT_SIZES = [17 , [17 ], (17 ,), None , [12 , 13 ], (12 , 13 )]
@@ -1012,7 +1084,7 @@ class TestHorizontalFlip:
10121084 def test_kernel_image (self , dtype , device ):
10131085 check_kernel (F .horizontal_flip_image , make_image (dtype = dtype , device = device ))
10141086
1015- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
1087+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
10161088 @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
10171089 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
10181090 def test_kernel_bounding_boxes (self , format , dtype , device ):
@@ -1071,17 +1143,22 @@ def test_image_correctness(self, fn):
10711143
10721144 torch .testing .assert_close (actual , expected )
10731145
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
    """Compute the expected result of a horizontal flip via the reference affine helpers.

    The flip is expressed as the affine map ``x -> canvas_size[1] - x``, ``y -> y``. Rotated box formats are
    dispatched to the rotated reference helper, every other format to the regular one.
    """
    flip_offset = bounding_boxes.canvas_size[1]
    affine_matrix = np.array(
        [
            [-1, 0, flip_offset],
            [0, 1, 0],
        ],
    )

    if tv_tensors.is_rotated_bounding_format(bounding_boxes.format):
        return reference_affine_rotated_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
    return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
10831160
1084- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
1161+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
10851162 @pytest .mark .parametrize (
10861163 "fn" , [F .horizontal_flip , transform_cls_to_functional (transforms .RandomHorizontalFlip , p = 1 )]
10871164 )
@@ -1464,7 +1541,7 @@ class TestVerticalFlip:
14641541 def test_kernel_image (self , dtype , device ):
14651542 check_kernel (F .vertical_flip_image , make_image (dtype = dtype , device = device ))
14661543
1467- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
1544+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
14681545 @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .int64 ])
14691546 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
14701547 def test_kernel_bounding_boxes (self , format , dtype , device ):
@@ -1521,17 +1598,22 @@ def test_image_correctness(self, fn):
15211598
15221599 torch .testing .assert_close (actual , expected )
15231600
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
    """Compute the expected result of a vertical flip via the reference affine helpers.

    The flip is expressed as the affine map ``x -> x``, ``y -> canvas_size[0] - y``. Rotated box formats are
    dispatched to the rotated reference helper, every other format to the regular one.
    """
    flip_offset = bounding_boxes.canvas_size[0]
    affine_matrix = np.array(
        [
            [1, 0, 0],
            [0, -1, flip_offset],
        ],
    )

    if tv_tensors.is_rotated_bounding_format(bounding_boxes.format):
        return reference_affine_rotated_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
    return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
15331615
1534- @pytest .mark .parametrize ("format" , SUPPORTED_BOX_FORMATS )
1616+ @pytest .mark .parametrize ("format" , list ( tv_tensors . BoundingBoxFormat ) )
15351617 @pytest .mark .parametrize ("fn" , [F .vertical_flip , transform_cls_to_functional (transforms .RandomVerticalFlip , p = 1 )])
15361618 def test_bounding_boxes_correctness (self , format , fn ):
15371619 bounding_boxes = make_bounding_boxes (format = format )
0 commit comments