Skip to content

Commit ee27585

Browse files
fix all cases for _parallelogram_to_bounding_boxes
1 parent 0a2a419 commit ee27585

File tree

2 files changed

+51
-50
lines changed

2 files changed

+51
-50
lines changed

test/test_transforms_v2.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7001,6 +7001,21 @@ def test_parallelogram_to_bounding_boxes(input_size, device):
70017001
actual = _parallelogram_to_bounding_boxes(parallelogram)
70027002
torch.testing.assert_close(actual, expected)
70037003

7004+
# Test the transformation of a simple parallelogram.
7005+
# 1
7006+
# 1-2 / 2
7007+
# / / -> / /
7008+
# 4-3 4 /
7009+
# 3
7010+
parallelogram = torch.tensor(
7011+
[[0, 4, 3, 1, 5, 1, 2, 4]],
7012+
dtype=torch.float32,
7013+
)
7014+
expected = torch.tensor(
7015+
[[0, 4, 4, 0, 5, 1, 1, 5]],
7016+
dtype=torch.float32,
7017+
)
7018+
70047019

70057020
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
70067021
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 36 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -451,56 +451,42 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
451451
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
452452
The output maintains the same dtype as the input.
453453
"""
454-
out_boxes = parallelogram.clone()
455-
456-
# Calculate parallelogram diagonal vectors
457-
dx13 = parallelogram[..., 4] - parallelogram[..., 0]
458-
dy13 = parallelogram[..., 5] - parallelogram[..., 1]
459-
dx42 = parallelogram[..., 2] - parallelogram[..., 6]
460-
dy42 = parallelogram[..., 3] - parallelogram[..., 7]
461-
dx12 = parallelogram[..., 2] - parallelogram[..., 0]
462-
dy12 = parallelogram[..., 1] - parallelogram[..., 3]
463-
diag13 = torch.sqrt(dx13**2 + dy13**2)
464-
diag24 = torch.sqrt(dx42**2 + dy42**2)
465-
466-
# Calculate rotation angle in radians
467-
r_rad = torch.atan2(dy12, dx12)
468-
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
469-
470-
# Calculate width using the angle between diagonal and rotation
471-
w13 = diag13 * torch.abs(torch.sin(torch.atan2(dx13, dy13) - r_rad))
472-
delta_x13 = w13 * cos
473-
delta_y13 = w13 * sin
474-
w24 = diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad))
475-
delta_x24 = w24 * cos
476-
delta_y24 = w24 * sin
477-
478-
# Calculate the area of the triangle formed by the three points
479-
# Area = 1/2 * |det([x1, y1, 1], [x2, y2, 1], [x3, y3, 1])|
480-
# For points (x1, y1), (x1 - delta_x, y1 + delta_y), (x3, y3)
481-
# This simplifies to 1/2 * |delta_x * (y3 - y1) - delta_y * (x3 - x1)|
482-
area13 = 0.5 * torch.abs(delta_x13 * dy13 - delta_y13 * dx13)
483-
# For points (x4, y4), (x4 - delta_x, y4 + delta_y), (x2, y2)
484-
# This simplifies to 1/2 * |delta_x * (y2 - y4) - delta_y * (x2 - x4)|
485-
area24 = 0.5 * torch.abs(delta_x24 * dy42 - delta_y24 * dx42)
486-
487-
# We keep the rectangle with the smallest area
488-
mask = area13 < area24
489-
delta_x = torch.where(mask, delta_x13, delta_x24)
490-
delta_y = torch.where(mask, delta_y13, delta_y24)
491-
492-
# Update coordinates to form a rectangle
493-
# Keeping the points (x1, y1) and (x3, y3) unchanged.
494-
out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2])
495-
out_boxes[..., 3] = torch.where(mask, parallelogram[..., 1] - delta_y, parallelogram[..., 3])
496-
out_boxes[..., 6] = torch.where(mask, parallelogram[..., 4] - delta_x, parallelogram[..., 6])
497-
out_boxes[..., 7] = torch.where(mask, parallelogram[..., 5] + delta_y, parallelogram[..., 7])
498-
499-
# Keeping the points (x2, y2) and (x4, y4) unchanged.
500-
out_boxes[..., 0] = torch.where(~mask, parallelogram[..., 2] - delta_x, parallelogram[..., 0])
501-
out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1])
502-
out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4])
503-
out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5])
454+
original_shape = parallelogram.shape
455+
dtype = parallelogram.dtype
456+
acceptable_dtypes = [torch.float32, torch.float64]
457+
need_cast = dtype not in acceptable_dtypes
458+
if need_cast:
459+
# Up-case to avoid overflow for square operations
460+
parallelogram = parallelogram.to(torch.float32)
461+
462+
x1, y1, x2, y2, x3, y3, x4, y4 = parallelogram.unbind(-1)
463+
cx = (x1 + x3) / 2
464+
cy = (y1 + y3) / 2
465+
466+
# Calculate width, height, and rotation angle of the parallelogram
467+
wp = torch.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
468+
hp = torch.sqrt((x4 - x1) ** 2 + (y4 - y1) ** 2)
469+
r12 = torch.atan2(y1 - y2, x2 - x1)
470+
r14 = torch.atan2(y1 - y4, x4 - x1)
471+
r_rad = r12 - r14
472+
sign = torch.where(r_rad > torch.pi / 2, -1, 1)
473+
cos, sin = r_rad.cos(), r_rad.sin()
474+
475+
# Calculate width, height, and rotation angle of the rectangle
476+
w = torch.where(wp < hp, wp * sin, wp + hp * cos * sign)
477+
h = torch.where(wp > hp, hp * sin, hp + wp * cos * sign)
478+
r_rad = torch.where(hp > wp, r14 + torch.pi / 2, r12)
479+
cos, sin = r_rad.cos(), r_rad.sin()
480+
481+
out_boxes = convert_bounding_box_format(
482+
torch.stack((cx, cy, w, h, r_rad * 180 / torch.pi), dim=-1),
483+
old_format=tv_tensors.BoundingBoxFormat.CXCYWHR,
484+
new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY,
485+
inplace=False,
486+
).reshape(original_shape)
487+
488+
if need_cast:
489+
out_boxes = out_boxes.to(dtype)
504490
return out_boxes
505491

506492

0 commit comments

Comments
 (0)