@@ -2017,23 +2017,27 @@ def elastic_bounding_boxes(
20172017 # TODO: add in docstring about approximation we are doing for grid inversion
20182018 device = bounding_boxes .device
20192019 dtype = bounding_boxes .dtype if torch .is_floating_point (bounding_boxes ) else torch .float32
2020+ is_rotated = tv_tensors .is_rotated_bounding_format (format )
20202021
20212022 if displacement .dtype != dtype or displacement .device != device :
20222023 displacement = displacement .to (dtype = dtype , device = device )
20232024
20242025 original_shape = bounding_boxes .shape
20252026 # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
2027+ intermediate_format = tv_tensors .BoundingBoxFormat .XYXYXYXY if is_rotated else tv_tensors .BoundingBoxFormat .XYXY
2028+
20262029 bounding_boxes = (
2027- convert_bounding_box_format (bounding_boxes , old_format = format , new_format = tv_tensors . BoundingBoxFormat . XYXY )
2028- ).reshape (- 1 , 4 )
2030+ convert_bounding_box_format (bounding_boxes . clone () , old_format = format , new_format = intermediate_format )
2031+ ).reshape (- 1 , 8 if is_rotated else 4 )
20292032
20302033 id_grid = _create_identity_grid (canvas_size , device = device , dtype = dtype )
20312034 # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
20322035 # This is not an exact inverse of the grid
20332036 inv_grid = id_grid .sub_ (displacement )
20342037
20352038 # Get points from bboxes
2036- points = bounding_boxes [:, [[0 , 1 ], [2 , 1 ], [2 , 3 ], [0 , 3 ]]].reshape (- 1 , 2 )
2039+ points = bounding_boxes if is_rotated else bounding_boxes [:, [[0 , 1 ], [2 , 1 ], [2 , 3 ], [0 , 3 ]]]
2040+ points = points .reshape (- 1 , 2 )
20372041 if points .is_floating_point ():
20382042 points = points .ceil_ ()
20392043 index_xy = points .to (dtype = torch .long )
@@ -2043,16 +2047,22 @@ def elastic_bounding_boxes(
20432047 t_size = torch .tensor (canvas_size [::- 1 ], device = displacement .device , dtype = displacement .dtype )
20442048 transformed_points = inv_grid [0 , index_y , index_x , :].add_ (1 ).mul_ (0.5 * t_size ).sub_ (0.5 )
20452049
2046- transformed_points = transformed_points .reshape (- 1 , 4 , 2 )
2047- out_bbox_mins , out_bbox_maxs = torch .aminmax (transformed_points , dim = 1 )
2050+ if is_rotated :
2051+ transformed_points = transformed_points .reshape (- 1 , 8 )
2052+ out_bboxes = _parallelogram_to_bounding_boxes (transformed_points ).to (bounding_boxes .dtype )
2053+ else :
2054+ transformed_points = transformed_points .reshape (- 1 , 4 , 2 )
2055+ out_bbox_mins , out_bbox_maxs = torch .aminmax (transformed_points , dim = 1 )
2056+ out_bboxes = torch .cat ([out_bbox_mins , out_bbox_maxs ], dim = 1 ).to (bounding_boxes .dtype )
2057+
20482058 out_bboxes = clamp_bounding_boxes (
2049- torch . cat ([ out_bbox_mins , out_bbox_maxs ], dim = 1 ). to ( bounding_boxes . dtype ) ,
2050- format = tv_tensors . BoundingBoxFormat . XYXY ,
2059+ out_bboxes ,
2060+ format = intermediate_format ,
20512061 canvas_size = canvas_size ,
20522062 )
20532063
20542064 return convert_bounding_box_format (
2055- out_bboxes , old_format = tv_tensors . BoundingBoxFormat . XYXY , new_format = format , inplace = True
2065+ out_bboxes , old_format = intermediate_format , new_format = format , inplace = False
20562066 ).reshape (original_shape )
20572067
20582068
0 commit comments