@@ -451,56 +451,42 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
451451        torch.Tensor: Tensor of same shape as input containing the rectangle coordinates. 
452452                     The output maintains the same dtype as the input. 
453453    """ 
454-     out_boxes  =  parallelogram .clone ()
455- 
456-     # Calculate parallelogram diagonal vectors 
457-     dx13  =  parallelogram [..., 4 ] -  parallelogram [..., 0 ]
458-     dy13  =  parallelogram [..., 5 ] -  parallelogram [..., 1 ]
459-     dx42  =  parallelogram [..., 2 ] -  parallelogram [..., 6 ]
460-     dy42  =  parallelogram [..., 3 ] -  parallelogram [..., 7 ]
461-     dx12  =  parallelogram [..., 2 ] -  parallelogram [..., 0 ]
462-     dy12  =  parallelogram [..., 1 ] -  parallelogram [..., 3 ]
463-     diag13  =  torch .sqrt (dx13 ** 2  +  dy13 ** 2 )
464-     diag24  =  torch .sqrt (dx42 ** 2  +  dy42 ** 2 )
465- 
466-     # Calculate rotation angle in radians 
467-     r_rad  =  torch .atan2 (dy12 , dx12 )
468-     cos , sin  =  torch .cos (r_rad ), torch .sin (r_rad )
469- 
470-     # Calculate width using the angle between diagonal and rotation 
471-     w13  =  diag13  *  torch .abs (torch .sin (torch .atan2 (dx13 , dy13 ) -  r_rad ))
472-     delta_x13  =  w13  *  cos 
473-     delta_y13  =  w13  *  sin 
474-     w24  =  diag24  *  torch .abs (torch .sin (torch .atan2 (dx42 , dy42 ) -  r_rad ))
475-     delta_x24  =  w24  *  cos 
476-     delta_y24  =  w24  *  sin 
477- 
478-     # Calculate the area of the triangle formed by the three points 
479-     # Area = 1/2 * |det([x1, y1, 1], [x2, y2, 1], [x3, y3, 1])| 
480-     # For points (x1, y1), (x1 - delta_x, y1 + delta_y), (x3, y3) 
481-     # This simplifies to 1/2 * |delta_x * (y3 - y1) - delta_y * (x3 - x1)| 
482-     area13  =  0.5  *  torch .abs (delta_x13  *  dy13  -  delta_y13  *  dx13 )
483-     # For points (x4, y4), (x4 - delta_x, y4 + delta_y), (x2, y2) 
484-     # This simplifies to 1/2 * |delta_x * (y2 - y4) - delta_y * (x2 - x4)| 
485-     area24  =  0.5  *  torch .abs (delta_x24  *  dy42  -  delta_y24  *  dx42 )
486- 
487-     # We keep the rectangle with the smallest area 
488-     mask  =  area13  <  area24 
489-     delta_x  =  torch .where (mask , delta_x13 , delta_x24 )
490-     delta_y  =  torch .where (mask , delta_y13 , delta_y24 )
491- 
492-     # Update coordinates to form a rectangle 
493-     # Keeping the points (x1, y1) and (x3, y3) unchanged. 
494-     out_boxes [..., 2 ] =  torch .where (mask , parallelogram [..., 0 ] +  delta_x , parallelogram [..., 2 ])
495-     out_boxes [..., 3 ] =  torch .where (mask , parallelogram [..., 1 ] -  delta_y , parallelogram [..., 3 ])
496-     out_boxes [..., 6 ] =  torch .where (mask , parallelogram [..., 4 ] -  delta_x , parallelogram [..., 6 ])
497-     out_boxes [..., 7 ] =  torch .where (mask , parallelogram [..., 5 ] +  delta_y , parallelogram [..., 7 ])
498- 
499-     # Keeping the points (x2, y2) and (x4, y4) unchanged. 
500-     out_boxes [..., 0 ] =  torch .where (~ mask , parallelogram [..., 2 ] -  delta_x , parallelogram [..., 0 ])
501-     out_boxes [..., 1 ] =  torch .where (~ mask , parallelogram [..., 3 ] +  delta_y , parallelogram [..., 1 ])
502-     out_boxes [..., 4 ] =  torch .where (~ mask , parallelogram [..., 6 ] +  delta_x , parallelogram [..., 4 ])
503-     out_boxes [..., 5 ] =  torch .where (~ mask , parallelogram [..., 7 ] -  delta_y , parallelogram [..., 5 ])
454+     original_shape  =  parallelogram .shape 
455+     dtype  =  parallelogram .dtype 
456+     acceptable_dtypes  =  [torch .float32 , torch .float64 ]
457+     need_cast  =  dtype  not  in   acceptable_dtypes 
458+     if  need_cast :
459+         # Up-case to avoid overflow for square operations 
460+         parallelogram  =  parallelogram .to (torch .float32 )
461+ 
462+     x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4  =  parallelogram .unbind (- 1 )
463+     cx  =  (x1  +  x3 ) /  2 
464+     cy  =  (y1  +  y3 ) /  2 
465+ 
466+     # Calculate width, height, and rotation angle of the parallelogram 
467+     wp  =  torch .sqrt ((x2  -  x1 ) **  2  +  (y2  -  y1 ) **  2 )
468+     hp  =  torch .sqrt ((x4  -  x1 ) **  2  +  (y4  -  y1 ) **  2 )
469+     r12  =  torch .atan2 (y1  -  y2 , x2  -  x1 )
470+     r14  =  torch .atan2 (y1  -  y4 , x4  -  x1 )
471+     r_rad  =  r12  -  r14 
472+     sign  =  torch .where (r_rad  >  torch .pi  /  2 , - 1 , 1 )
473+     cos , sin  =  r_rad .cos (), r_rad .sin ()
474+ 
475+     # Calculate width, height, and rotation angle of the rectangle 
476+     w  =  torch .where (wp  <  hp , wp  *  sin , wp  +  hp  *  cos  *  sign )
477+     h  =  torch .where (wp  >  hp , hp  *  sin , hp  +  wp  *  cos  *  sign )
478+     r_rad  =  torch .where (hp  >  wp , r14  +  torch .pi  /  2 , r12 )
479+     cos , sin  =  r_rad .cos (), r_rad .sin ()
480+ 
481+     out_boxes  =  convert_bounding_box_format (
482+         torch .stack ((cx , cy , w , h , r_rad  *  180  /  torch .pi ), dim = - 1 ),
483+         old_format = tv_tensors .BoundingBoxFormat .CXCYWHR ,
484+         new_format = tv_tensors .BoundingBoxFormat .XYXYXYXY ,
485+         inplace = False ,
486+     ).reshape (original_shape )
487+ 
488+     if  need_cast :
489+         out_boxes  =  out_boxes .to (dtype )
504490    return  out_boxes 
505491
506492
0 commit comments