Skip to content

Commit 37c72e4

Browse files
Fix _parallelogram_to_bounding_boxes
1 parent b818d32 commit 37c72e4

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -462,21 +462,34 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
462462
dy12 = parallelogram[..., 1] - parallelogram[..., 3]
463463
diag13 = torch.sqrt(dx13**2 + dy13**2)
464464
diag24 = torch.sqrt(dx42**2 + dy42**2)
465-
mask = diag13 > diag24
466465

467466
# Calculate rotation angle in radians
468467
r_rad = torch.atan2(dy12, dx12)
469468
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
470469

471470
# Calculate width using the angle between diagonal and rotation
472-
w = torch.where(
473-
mask,
474-
diag13 * torch.abs(torch.sin(torch.atan2(dx13, dy13) - r_rad)),
475-
diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad)),
476-
)
471+
w13 = diag13 * torch.abs(torch.sin(torch.atan2(dx13, dy13) - r_rad))
472+
delta_x13 = w13 * cos
473+
delta_y13 = w13 * sin
474+
w24 = diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad))
475+
delta_x24 = w24 * cos
476+
delta_y24 = w24 * sin
477+
478+
# Calculate the area of the triangle formed by the three points
479+
# Area = 1/2 * |det([x1, y1, 1], [x2, y2, 1], [x3, y3, 1])|
480+
# For points (x1, y1), (x1 - delta_x, y1 + delta_y), (x3, y3)
481+
# This simplifies to 1/2 * |delta_x * (y3 - y1) - delta_y * (x3 - x1)|
482+
area13 = 0.5 * torch.abs(delta_x13 * dy13 - delta_y13 * dx13)
483+
# For points (x4, y4), (x4 - delta_x, y4 + delta_y), (x2, y2)
484+
# This simplifies to 1/2 * |delta_x * (y2 - y4) - delta_y * (x2 - x4)|
485+
area24 = 0.5 * torch.abs(delta_x24 * dy42 - delta_y24 * dx42)
486+
487+
# We keep the rectangle with the smallest area
488+
mask = area13 < area24
489+
delta_x = torch.where(mask, delta_x13, delta_x24)
490+
delta_y = torch.where(mask, delta_y13, delta_y24)
491+
477492

478-
delta_x = w * cos
479-
delta_y = w * sin
480493
# Update coordinates to form a rectangle
481494
# Keeping the points (x1, y1) and (x3, y3) unchanged.
482495
out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2])

0 commit comments

Comments
 (0)