@@ -383,11 +383,10 @@ def _resize_mask_dispatch(
383383
384384def _parallelogram_to_bounding_boxes (parallelogram : torch .Tensor ) -> torch .Tensor :
385385 """
386- Convert a parallelogram to a rectangle while keeping the points (x1, y1) and (x3, y3) unchanged.
387-
386+ Convert a parallelogram to a rectangle while keeping two points unchanged.
388387 This function transforms a parallelogram represented by 8 coordinates (4 points) into a rectangle.
389- The first point (x1, y1) and the third point (x3, y3) of the parallelogram remain fixed,
390- while the second and fourth points are adjusted to form a proper rectangle.
388+ The two diagonally opposed points of the parallelogram forming the longest diagonal remain fixed.
389+ The other points are adjusted to form a proper rectangle.
391390
392391 Note:
393392 This function is not applied in-place and will return a copy of the input tensor.
@@ -401,34 +400,52 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
401400 The output maintains the same dtype as the input.
402401 """
403402 dtype = parallelogram .dtype
404- int_dtype = dtype in (torch .uint8 ,
405- torch .int8 ,
406- torch .int16 ,
407- torch .int32 ,
408- torch .int64 ,
409- )
403+ int_dtype = dtype in (
404+ torch .uint8 ,
405+ torch .int8 ,
406+ torch .int16 ,
407+ torch .int32 ,
408+ torch .int64 ,
409+ )
410410
411- # Calculate diagonal vector from first to third point
412- dx = parallelogram [..., 4 ] - parallelogram [..., 0 ]
413- dy = parallelogram [..., 5 ] - parallelogram [..., 1 ]
414- diag = torch .sqrt (dx ** 2 + dy ** 2 )
411+ out_boxes = parallelogram .clone ()
412+
413+ # Calculate parallelogram diagonal vectors
414+ dx13 = parallelogram [..., 4 ] - parallelogram [..., 0 ]
415+ dy13 = parallelogram [..., 5 ] - parallelogram [..., 1 ]
416+ dx42 = parallelogram [..., 2 ] - parallelogram [..., 6 ]
417+ dy42 = parallelogram [..., 3 ] - parallelogram [..., 7 ]
418+ diag13 = torch .sqrt (dx13 ** 2 + dy13 ** 2 )
419+ diag24 = torch .sqrt (dx42 ** 2 + dy42 ** 2 )
420+ mask = diag13 > diag24
415421
416422 # Calculate rotation angle in radians
417423 r_rad = torch .atan2 (parallelogram [..., 1 ] - parallelogram [..., 3 ], parallelogram [..., 2 ] - parallelogram [..., 0 ])
418424 cos , sin = torch .cos (r_rad ), torch .sin (r_rad )
419425
420426 # Calculate width using the angle between diagonal and rotation
421- w = diag * torch .abs (torch .sin (torch .atan2 (dx , dy ) - r_rad ))
427+ w = torch .where (
428+ mask ,
429+ diag13 * torch .abs (torch .sin (torch .atan2 (dx13 , dy13 ) - r_rad )),
430+ diag24 * torch .abs (torch .sin (torch .atan2 (dx42 , dy42 ) - r_rad )),
431+ )
422432
423433 delta_x = torch .round (w * cos ).to (dtype ) if int_dtype else w * cos
424- detla_y = torch .round (w * sin ).to (dtype ) if int_dtype else w * sin
434+ delta_y = torch .round (w * sin ).to (dtype ) if int_dtype else w * sin
425435
426436 # Update coordinates to form a rectangle
427- parallelogram [..., 2 ] = parallelogram [..., 0 ] + delta_x
428- parallelogram [..., 3 ] = parallelogram [..., 1 ] - detla_y
429- parallelogram [..., 6 ] = parallelogram [..., 4 ] - delta_x
430- parallelogram [..., 7 ] = parallelogram [..., 5 ] + detla_y
431- return parallelogram
437+ # Keeping the points (x1, y1) and (x3, y3) unchanged.
438+ out_boxes [..., 2 ] = torch .where (mask , parallelogram [..., 0 ] + delta_x , parallelogram [..., 2 ])
439+ out_boxes [..., 3 ] = torch .where (mask , parallelogram [..., 1 ] - delta_y , parallelogram [..., 3 ])
440+ out_boxes [..., 6 ] = torch .where (mask , parallelogram [..., 4 ] - delta_x , parallelogram [..., 6 ])
441+ out_boxes [..., 7 ] = torch .where (mask , parallelogram [..., 5 ] + delta_y , parallelogram [..., 7 ])
442+
443+ # Keeping the points (x2, y2) and (x4, y4) unchanged.
444+ out_boxes [..., 0 ] = torch .where (~ mask , parallelogram [..., 2 ] - delta_x , parallelogram [..., 0 ])
445+ out_boxes [..., 1 ] = torch .where (~ mask , parallelogram [..., 3 ] + delta_y , parallelogram [..., 1 ])
446+ out_boxes [..., 4 ] = torch .where (~ mask , parallelogram [..., 6 ] + delta_x , parallelogram [..., 4 ])
447+ out_boxes [..., 5 ] = torch .where (~ mask , parallelogram [..., 7 ] - delta_y , parallelogram [..., 5 ])
448+ return out_boxes
432449
433450
434451def resize_bounding_boxes (
0 commit comments