@@ -88,15 +88,21 @@ def horizontal_flip_bounding_boxes(
8888 bounding_boxes [:, 0 ].sub_ (canvas_size [1 ]).neg_ ()
8989 elif format == tv_tensors .BoundingBoxFormat .XYXYXYXY :
9090 bounding_boxes [:, 0 ::2 ].sub_ (canvas_size [1 ]).neg_ ()
91- bounding_boxes = bounding_boxes [:, [0 , 1 , 6 , 7 , 4 , 5 , 2 , 3 ]]
91+ bounding_boxes = bounding_boxes [:, [2 , 3 , 0 , 1 , 6 , 7 , 4 , 5 ]]
9292 elif format == tv_tensors .BoundingBoxFormat .XYWHR :
93- bounding_boxes [:, 0 ].sub_ (canvas_size [1 ]).neg_ ()
94- bounding_boxes = bounding_boxes [:, [0 , 1 , 3 , 2 , 4 ]]
95- bounding_boxes [:, - 1 ].add_ (90 ).neg_ ()
93+
94+ dtype = bounding_boxes .dtype
95+ if not torch .is_floating_point (bounding_boxes ):
96+ # Casting to float to support cos and sin computations.
97+ bounding_boxes = bounding_boxes .to (torch .float64 )
98+ angle_rad = bounding_boxes [:, 4 ].mul (torch .pi ).div (180 )
99+ bounding_boxes [:, 0 ].add_ (bounding_boxes [:, 2 ].mul (angle_rad .cos ())).sub_ (canvas_size [1 ]).neg_ ()
100+ bounding_boxes [:, 1 ].sub_ (bounding_boxes [:, 2 ].mul (angle_rad .sin ()))
101+ bounding_boxes [:, 4 ].neg_ ()
102+ bounding_boxes = bounding_boxes .to (dtype )
96103 else : # format == tv_tensors.BoundingBoxFormat.CXCYWHR:
97104 bounding_boxes [:, 0 ].sub_ (canvas_size [1 ]).neg_ ()
98- bounding_boxes = bounding_boxes [:, [0 , 1 , 3 , 2 , 4 ]]
99- bounding_boxes [:, - 1 ].add_ (90 ).neg_ ()
105+ bounding_boxes [:, 4 ].neg_ ()
100106
101107 return bounding_boxes .reshape (shape )
102108
0 commit comments