@@ -893,11 +893,15 @@ def _affine_bounding_boxes_with_expand(
893893 bounding_boxes = bounding_boxes .clone () if bounding_boxes .is_floating_point () else bounding_boxes .float ()
894894 dtype = bounding_boxes .dtype
895895 device = bounding_boxes .device
896+ intermediate_format = (
897+ tv_tensors .BoundingBoxFormat .XYXYXYXY
898+ if tv_tensors .is_rotated_bounding_format (format )
899+ else tv_tensors .BoundingBoxFormat .XYXY
900+ )
901+ intermediate_shape = 8 if tv_tensors .is_rotated_bounding_format (format ) else 4
896902 bounding_boxes = (
897- convert_bounding_box_format (
898- bounding_boxes , old_format = format , new_format = tv_tensors .BoundingBoxFormat .XYXY , inplace = True
899- )
900- ).reshape (- 1 , 4 )
903+ convert_bounding_box_format (bounding_boxes , old_format = format , new_format = intermediate_format , inplace = True )
904+ ).reshape (- 1 , intermediate_shape )
901905
902906 angle , translate , shear , center = _affine_parse_args (
903907 angle , translate , scale , shear , InterpolationMode .NEAREST , center
@@ -921,15 +925,22 @@ def _affine_bounding_boxes_with_expand(
921925 # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
922926 # Single point structure is similar to
923927 # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
924- points = bounding_boxes [:, [[0 , 1 ], [2 , 1 ], [2 , 3 ], [0 , 3 ]]].reshape (- 1 , 2 )
928+ if tv_tensors .is_rotated_bounding_format (format ):
929+ points = bounding_boxes .reshape (- 1 , 2 )
930+ else :
931+ points = bounding_boxes [:, [[0 , 1 ], [2 , 1 ], [2 , 3 ], [0 , 3 ]]].reshape (- 1 , 2 )
925932 points = torch .cat ([points , torch .ones (points .shape [0 ], 1 , device = device , dtype = dtype )], dim = - 1 )
926933 # 2) Now let's transform the points using affine matrix
927934 transformed_points = torch .matmul (points , transposed_affine_matrix )
928935 # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
929936 # and compute bounding box from 4 transformed points:
930- transformed_points = transformed_points .reshape (- 1 , 4 , 2 )
931- out_bbox_mins , out_bbox_maxs = torch .aminmax (transformed_points , dim = 1 )
932- out_bboxes = torch .cat ([out_bbox_mins , out_bbox_maxs ], dim = 1 )
937+ if tv_tensors .is_rotated_bounding_format (format ):
938+ transformed_points = transformed_points .reshape (- 1 , 8 )
939+ out_bboxes = _parallelogram_to_bounding_boxes (transformed_points )
940+ else :
941+ transformed_points = transformed_points .reshape (- 1 , 4 , 2 )
942+ out_bbox_mins , out_bbox_maxs = torch .aminmax (transformed_points , dim = 1 )
943+ out_bboxes = torch .cat ([out_bbox_mins , out_bbox_maxs ], dim = 1 )
933944
934945 if expand :
935946 # Compute minimum point for transformed image frame:
@@ -954,9 +965,9 @@ def _affine_bounding_boxes_with_expand(
954965 new_width , new_height = _compute_affine_output_size (affine_vector , width , height )
955966 canvas_size = (new_height , new_width )
956967
957- out_bboxes = clamp_bounding_boxes (out_bboxes , format = tv_tensors . BoundingBoxFormat . XYXY , canvas_size = canvas_size )
968+ out_bboxes = clamp_bounding_boxes (out_bboxes , format = intermediate_format , canvas_size = canvas_size )
958969 out_bboxes = convert_bounding_box_format (
959- out_bboxes , old_format = tv_tensors . BoundingBoxFormat . XYXY , new_format = format , inplace = True
970+ out_bboxes , old_format = intermediate_format , new_format = format , inplace = True
960971 ).reshape (original_shape )
961972
962973 out_bboxes = out_bboxes .to (original_dtype )
0 commit comments