Skip to content

Commit a1156f1

Browse files
Adjust _parallelogram_to_bounding_boxes
Test Plan: Run unit tests ```bash pytest test/test_transforms_v2.py -vvv -k "test_parallelogram_to_bounding_boxes" ```
1 parent c96676a commit a1156f1

File tree

2 files changed

+61
-22
lines changed

2 files changed

+61
-22
lines changed

test/test_transforms_v2.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5943,10 +5943,32 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t
59435943
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
59445944
@pytest.mark.parametrize("device", cpu_and_cuda())
59455945
def test_parallelogram_to_bounding_boxes(input_size, dtype, device):
5946-
bounding_boxes = make_bounding_boxes(input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, dtype=dtype, device=device)
5946+
# Assert that applying `_parallelogram_to_bounding_boxes` to rotated boxes
5947+
# does not modify the input.
5948+
bounding_boxes = make_bounding_boxes(
5949+
input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, dtype=dtype, device=device
5950+
)
59475951
actual = _parallelogram_to_bounding_boxes(bounding_boxes)
59485952
torch.testing.assert_close(actual, bounding_boxes, rtol=0, atol=1)
59495953

5954+
# Test the transformation of two simple parallelograms.
5955+
# 1---2 1----2
5956+
# / / -> | |
5957+
# 4---3 4----3
5958+
5959+
# 1---2 1----2
5960+
# \ \ -> | |
5961+
# 4---3 4----3
5962+
parallelogram = torch.tensor([[1, 0, 4, 0, 3, 2, 0, 2], [0, 0, 3, 0, 4, 2, 1, 2]])
5963+
expected = torch.tensor(
5964+
[
5965+
[0, 0, 4, 0, 4, 2, 0, 2],
5966+
[0, 0, 4, 0, 4, 2, 0, 2],
5967+
]
5968+
)
5969+
actual = _parallelogram_to_bounding_boxes(parallelogram)
5970+
assert_equal(actual, expected)
5971+
59505972

59515973
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
59525974
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,10 @@ def _resize_mask_dispatch(
383383

384384
def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tensor:
385385
"""
386-
Convert a parallelogram to a rectangle while keeping the points (x1, y1) and (x3, y3) unchanged.
387-
386+
Convert a parallelogram to a rectangle while keeping two points unchanged.
388387
This function transforms a parallelogram represented by 8 coordinates (4 points) into a rectangle.
389-
The first point (x1, y1) and the third point (x3, y3) of the parallelogram remain fixed,
390-
while the second and fourth points are adjusted to form a proper rectangle.
388+
The two diagonally opposed points of the parallelogram forming the longest diagonal remain fixed.
389+
The other points are adjusted to form a proper rectangle.
391390
392391
Note:
393392
This function is not applied in-place and will return a copy of the input tensor.
@@ -401,34 +400,52 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
401400
The output maintains the same dtype as the input.
402401
"""
403402
dtype = parallelogram.dtype
404-
int_dtype = dtype in (torch.uint8,
405-
torch.int8,
406-
torch.int16,
407-
torch.int32,
408-
torch.int64,
409-
)
403+
int_dtype = dtype in (
404+
torch.uint8,
405+
torch.int8,
406+
torch.int16,
407+
torch.int32,
408+
torch.int64,
409+
)
410410

411-
# Calculate diagonal vector from first to third point
412-
dx = parallelogram[..., 4] - parallelogram[..., 0]
413-
dy = parallelogram[..., 5] - parallelogram[..., 1]
414-
diag = torch.sqrt(dx**2 + dy**2)
411+
out_boxes = parallelogram.clone()
412+
413+
# Calculate parallelogram diagonal vectors
414+
dx13 = parallelogram[..., 4] - parallelogram[..., 0]
415+
dy13 = parallelogram[..., 5] - parallelogram[..., 1]
416+
dx42 = parallelogram[..., 2] - parallelogram[..., 6]
417+
dy42 = parallelogram[..., 3] - parallelogram[..., 7]
418+
diag13 = torch.sqrt(dx13**2 + dy13**2)
419+
diag24 = torch.sqrt(dx42**2 + dy42**2)
420+
mask = diag13 > diag24
415421

416422
# Calculate rotation angle in radians
417423
r_rad = torch.atan2(parallelogram[..., 1] - parallelogram[..., 3], parallelogram[..., 2] - parallelogram[..., 0])
418424
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
419425

420426
# Calculate width using the angle between diagonal and rotation
421-
w = diag * torch.abs(torch.sin(torch.atan2(dx, dy) - r_rad))
427+
w = torch.where(
428+
mask,
429+
diag13 * torch.abs(torch.sin(torch.atan2(dx13, dy13) - r_rad)),
430+
diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad)),
431+
)
422432

423433
delta_x = torch.round(w * cos).to(dtype) if int_dtype else w * cos
424-
detla_y = torch.round(w * sin).to(dtype) if int_dtype else w * sin
434+
delta_y = torch.round(w * sin).to(dtype) if int_dtype else w * sin
425435

426436
# Update coordinates to form a rectangle
427-
parallelogram[..., 2] = parallelogram[..., 0] + delta_x
428-
parallelogram[..., 3] = parallelogram[..., 1] - detla_y
429-
parallelogram[..., 6] = parallelogram[..., 4] - delta_x
430-
parallelogram[..., 7] = parallelogram[..., 5] + detla_y
431-
return parallelogram
437+
# Keeping the points (x1, y1) and (x3, y3) unchanged.
438+
out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2])
439+
out_boxes[..., 3] = torch.where(mask, parallelogram[..., 1] - delta_y, parallelogram[..., 3])
440+
out_boxes[..., 6] = torch.where(mask, parallelogram[..., 4] - delta_x, parallelogram[..., 6])
441+
out_boxes[..., 7] = torch.where(mask, parallelogram[..., 5] + delta_y, parallelogram[..., 7])
442+
443+
# Keeping the points (x2, y2) and (x4, y4) unchanged.
444+
out_boxes[..., 0] = torch.where(~mask, parallelogram[..., 2] - delta_x, parallelogram[..., 0])
445+
out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1])
446+
out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4])
447+
out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5])
448+
return out_boxes
432449

433450

434451
def resize_bounding_boxes(

0 commit comments

Comments
 (0)