@@ -451,54 +451,45 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
451451        torch.Tensor: Tensor of same shape as input containing the rectangle coordinates. 
452452                     The output maintains the same dtype as the input. 
453453    """ 
454+     original_shape  =  parallelogram .shape 
454455    dtype  =  parallelogram .dtype 
455456    acceptable_dtypes  =  [torch .float32 , torch .float64 ]
456457    need_cast  =  dtype  not  in acceptable_dtypes 
457458    if  need_cast :
458459        # Up-case to avoid overflow for square operations 
459460        parallelogram  =  parallelogram .to (torch .float32 )
460-     out_boxes  =  parallelogram .clone ()
461- 
462-     # Calculate parallelogram diagonal vectors 
463-     dx13  =  parallelogram [..., 4 ] -  parallelogram [..., 0 ]
464-     dy13  =  parallelogram [..., 5 ] -  parallelogram [..., 1 ]
465-     dx42  =  parallelogram [..., 2 ] -  parallelogram [..., 6 ]
466-     dy42  =  parallelogram [..., 3 ] -  parallelogram [..., 7 ]
467-     dx12  =  parallelogram [..., 2 ] -  parallelogram [..., 0 ]
468-     dy12  =  parallelogram [..., 1 ] -  parallelogram [..., 3 ]
469-     diag13  =  torch .sqrt (dx13 ** 2  +  dy13 ** 2 )
470-     diag24  =  torch .sqrt (dx42 ** 2  +  dy42 ** 2 )
471-     mask  =  diag13  >  diag24 
472- 
473-     # Calculate rotation angle in radians 
474-     r_rad  =  torch .atan2 (dy12 , dx12 )
475-     cos , sin  =  torch .cos (r_rad ), torch .sin (r_rad )
476- 
477-     # Calculate width using the angle between diagonal and rotation 
478-     w  =  torch .where (
479-         mask ,
480-         diag13  *  torch .abs (torch .sin (torch .atan2 (dx13 , dy13 ) -  r_rad )),
481-         diag24  *  torch .abs (torch .sin (torch .atan2 (dx42 , dy42 ) -  r_rad )),
482-     )
483461
484-     delta_x  =  w  *  cos 
485-     delta_y  =  w  *  sin 
486-     # Update coordinates to form a rectangle 
487-     # Keeping the points (x1, y1) and (x3, y3) unchanged. 
488-     out_boxes [..., 2 ] =  torch .where (mask , parallelogram [..., 0 ] +  delta_x , parallelogram [..., 2 ])
489-     out_boxes [..., 3 ] =  torch .where (mask , parallelogram [..., 1 ] -  delta_y , parallelogram [..., 3 ])
490-     out_boxes [..., 6 ] =  torch .where (mask , parallelogram [..., 4 ] -  delta_x , parallelogram [..., 6 ])
491-     out_boxes [..., 7 ] =  torch .where (mask , parallelogram [..., 5 ] +  delta_y , parallelogram [..., 7 ])
492- 
493-     # Keeping the points (x2, y2) and (x4, y4) unchanged. 
494-     out_boxes [..., 0 ] =  torch .where (~ mask , parallelogram [..., 2 ] -  delta_x , parallelogram [..., 0 ])
495-     out_boxes [..., 1 ] =  torch .where (~ mask , parallelogram [..., 3 ] +  delta_y , parallelogram [..., 1 ])
496-     out_boxes [..., 4 ] =  torch .where (~ mask , parallelogram [..., 6 ] +  delta_x , parallelogram [..., 4 ])
497-     out_boxes [..., 5 ] =  torch .where (~ mask , parallelogram [..., 7 ] -  delta_y , parallelogram [..., 5 ])
462+     x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4  =  parallelogram .unbind (- 1 )
463+     cx  =  (x1  +  x3 ) /  2 
464+     cy  =  (y1  +  y3 ) /  2 
465+ 
466+     # Calculate width, height, and rotation angle of the parallelogram 
467+     wp  =  torch .sqrt ((x2  -  x1 ) **  2  +  (y2  -  y1 ) **  2 )
468+     hp  =  torch .sqrt ((x4  -  x1 ) **  2  +  (y4  -  y1 ) **  2 )
469+     r12  =  torch .atan2 (y1  -  y2 , x2  -  x1 )
470+     r14  =  torch .atan2 (y1  -  y4 , x4  -  x1 )
471+     r_rad  =  r12  -  r14 
472+     sign  =  torch .where (r_rad  >  torch .pi  /  2 , - 1 , 1 )
473+     cos , sin  =  r_rad .cos (), r_rad .sin ()
474+ 
475+     # Calculate width, height, and rotation angle of the rectangle 
476+     w  =  torch .where (wp  <  hp , wp  *  sin , wp  +  hp  *  cos  *  sign )
477+     h  =  torch .where (wp  >  hp , hp  *  sin , hp  +  wp  *  cos  *  sign )
478+     r_rad  =  torch .where (hp  >  wp , r14  +  torch .pi  /  2 , r12 )
479+     cos , sin  =  r_rad .cos (), r_rad .sin ()
480+ 
481+     x1  =  cx  -  w  /  2  *  cos  -  h  /  2  *  sin 
482+     y1  =  cy  -  h  /  2  *  cos  +  w  /  2  *  sin 
483+     x2  =  cx  +  w  /  2  *  cos  -  h  /  2  *  sin 
484+     y2  =  cy  -  h  /  2  *  cos  -  w  /  2  *  sin 
485+     x3  =  cx  +  w  /  2  *  cos  +  h  /  2  *  sin 
486+     y3  =  cy  +  h  /  2  *  cos  -  w  /  2  *  sin 
487+     x4  =  cx  -  w  /  2  *  cos  +  h  /  2  *  sin 
488+     y4  =  cy  +  h  /  2  *  cos  +  w  /  2  *  sin 
489+     out_boxes  =  torch .stack ((x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4 ), dim = - 1 ).reshape (original_shape )
498490
499491    if  need_cast :
500492        out_boxes  =  out_boxes .to (dtype )
501- 
502493    return  out_boxes 
503494
504495
0 commit comments