Skip to content

Commit 1105aa1

Browse files
Update clamp_bounding_boxes for rotated boxes
Test Plan: Unit tests:
```bash
pytest test/test_transforms_v2.py -vvv -k "TestClampBoundingBoxes and test_kernel"
pytest test/test_transforms_v2.py -vvv -k "TestClampBoundingBoxes and test_functional"
```
1 parent 207fe96 commit 1105aa1

File tree

2 files changed

+114
-6
lines changed

2 files changed

+114
-6
lines changed

test/test_transforms_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4770,7 +4770,7 @@ def test_correctness_image(self, mean, std, dtype, fn):
47704770

47714771

47724772
class TestClampBoundingBoxes:
4773-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
4773+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
47744774
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
47754775
@pytest.mark.parametrize("device", cpu_and_cuda())
47764776
def test_kernel(self, format, dtype, device):
@@ -4782,7 +4782,7 @@ def test_kernel(self, format, dtype, device):
47824782
canvas_size=bounding_boxes.canvas_size,
47834783
)
47844784

4785-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
4785+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
47864786
def test_functional(self, format):
47874787
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format))
47884788

torchvision/transforms/v2/functional/_meta.py

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,122 @@ def _clamp_bounding_boxes(
352352
return out_boxes.to(in_dtype)
353353

354354

355+
def _order_bounding_boxes_points(
356+
bounding_boxes: torch.Tensor, indices: torch.Tensor | None = None
357+
) -> tuple[torch.Tensor, torch.Tensor]:
358+
"""Re-order points in bounding boxes based on specific criteria or provided indices.
359+
360+
This function reorders the points of bounding boxes either according to provided indices or
361+
by a default ordering strategy. In the default strategy, (x1, y1) corresponds to the point
362+
with the lowest x value. If multiple points have the same lowest x value, the point with the
363+
lowest y value is chosen.
364+
365+
Args:
366+
bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates in format [x1, y1, x2, y2, x3, y3, x4, y4].
367+
indices (torch.Tensor | None): Optional tensor containing indices for reordering. If None, default ordering is applied.
368+
369+
Returns:
370+
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
371+
- indices: The indices used for reordering
372+
- reordered_boxes: The bounding boxes with reordered points
373+
"""
374+
if indices is None:
375+
output_xyxyxyxy = bounding_boxes.clone().reshape(-1, 8)
376+
x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2]
377+
y_max = torch.max(y, dim=1, keepdim=True)[0]
378+
_, x1 = (y_max - y).div(y_max).add(x.add(1).mul(100)).min(dim=1)
379+
indices = torch.ones_like(output_xyxyxyxy)
380+
indices[..., 0] = x1.mul(2)
381+
indices.cumsum_(1).remainder_(8)
382+
return indices, bounding_boxes.gather(1, indices.to(torch.int64))
383+
384+
385+
def area(box: torch.Tensor) -> torch.Tensor:
    """Return a size score for boxes given as four corner points.

    The score is the product of the *squared* lengths of the edge from point 1 to
    point 2 and the edge from point 2 to point 3. No square root is taken, so for a
    rectangle the result is the area squared rather than the area itself.

    Args:
        box (torch.Tensor): Coordinates [x1, y1, x2, y2, x3, y3, x4, y4]; the tensor is
            flattened to shape (-1, 8) before use.

    Returns:
        torch.Tensor: One score per box, shape (N,).
    """
    # reshape()/unbind() do not modify their input, so no copy is required.
    # The fourth point does not take part in the computation.
    x1, y1, x2, y2, x3, y3, _, _ = box.reshape(-1, 8).unbind(-1)
    w = (y2 - y1) ** 2 + (x2 - x1) ** 2
    h = (y3 - y2) ** 2 + (x3 - x2) ** 2
    return w * h
390+
391+
392+
def _clamp_along_y_axis(
    bounding_boxes: torch.Tensor,
) -> torch.Tensor:
    """
    Adjusts bounding boxes along the y-axis based on specific conditions.

    Boxes are given as four corner points [x1, y1, x2, y2, x3, y3, x4, y4] and are
    tested against the line x = 0. Depending on which points have a negative (or
    non-positive) x, one of four replacement boxes is selected; boxes matching none
    of the conditions are returned unchanged.

    The conditions only ever test point 1 for being strictly left of x = 0, so the
    caller is expected to pass boxes whose first point is the left-most one.

    Args:
        bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates.
            Any shape whose elements can be viewed as (-1, 8) is accepted.

    Returns:
        torch.Tensor: The adjusted bounding boxes, with the dtype and shape of the input.
    """
    original_dtype = bounding_boxes.dtype
    original_shape = bounding_boxes.shape
    # Work on a flat (N, 8) view throughout so that the per-box conditions and the
    # candidate boxes always line up with the input rows, whatever the input shape.
    boxes = bounding_boxes.reshape(-1, 8)
    x1, y1, x2, y2, x3, y3, x4, y4 = boxes.unbind(-1)

    # Slope of the edge from point 1 to point 2.
    a = (y2 - y1) / (x2 - x1)
    # y-intercepts of lines with slope a through points 1 and 3, and of lines with
    # slope -1/a through points 2 and 4. These may be inf/NaN for axis-aligned boxes;
    # such values only matter if the case that uses them is actually selected.
    b1 = y1 - a * x1
    b2 = y2 + x2 / a
    b3 = y3 - a * x3
    b4 = y4 + x4 / a
    # Half of the x coordinate at which the slope -1/a line through point 2 meets the
    # slope a line through point 3.
    b23 = (b2 - b3) / 2 * a / (1 + a**2)
    z = torch.zeros_like(b1)

    def _as_boxes(*columns: torch.Tensor) -> torch.Tensor:
        # Assemble eight per-box coordinate vectors into an (N, 8) tensor.
        return torch.cat([column.unsqueeze(1) for column in columns], dim=1)

    # Point 1 moves to (0, b1); point 4 becomes the new point 1 shifted by (p3 - p2).
    case_a = _as_boxes(z, b1, x2, y2, x3, y3, x3 - x2, y3 + b1 - y2)
    # Point 1 moves to (0, b4); point 2 becomes the new point 1 shifted by (p2 - p1).
    case_b = _as_boxes(z, b4, x2 - x1, y2 - y1 + b4, x3, y3, x4, y4)
    # Only point 3 is kept; points 2 and 4 move to x = b23 on their respective lines.
    case_c = _as_boxes(z, (b2 + b3) / 2, b23, -b23 / a + b2, x3, y3, b23, b23 * a + b3)
    # All four points are collapsed to the origin.
    case_d = torch.zeros_like(case_c)

    p1_left = x1.lt(0)
    only_p1_left = p1_left.logical_and(x2.ge(0)).logical_and(x3.ge(0)).logical_and(x4.ge(0))
    # Both comparisons are evaluated explicitly (rather than negating one) so that
    # NaN scores select neither case, exactly as two separate comparisons would.
    score_a = area(case_a)
    score_b = area(case_b)

    cond_a = only_p1_left.logical_and(score_a > score_b)
    cond_a = cond_a.logical_or(p1_left.logical_and(x2.ge(0)).logical_and(x3.ge(0)).logical_and(x4.le(0)))
    cond_b = only_p1_left.logical_and(score_a <= score_b)
    cond_b = cond_b.logical_or(p1_left.logical_and(x2.le(0)).logical_and(x3.ge(0)).logical_and(x4.ge(0)))
    cond_c = p1_left.logical_and(x2.le(0)).logical_and(x3.ge(0)).logical_and(x4.le(0))
    cond_d = p1_left.logical_and(x2.le(0)).logical_and(x3.le(0)).logical_and(x4.le(0))

    # Applied in order, so a later matching case overrides an earlier one.
    for cond, case in zip(
        [cond_a, cond_b, cond_c, cond_d],
        [case_a, case_b, case_c, case_d],
    ):
        boxes = torch.where(cond.unsqueeze(1), case, boxes)
    return boxes.to(original_dtype).reshape(original_shape)
440+
441+
355442
def _clamp_rotated_bounding_boxes(
    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int]
) -> torch.Tensor:
    """Clamp rotated bounding boxes by processing one side per quarter turn.

    The boxes are converted to the XYXYXYXY corner format and flattened to (-1, 8).
    Four times in a row the corners are re-ordered, clamped with
    ``_clamp_along_y_axis`` and then mapped through (x, y) -> (y, canvas_size[1] - x)
    while the two entries of ``canvas_size`` are swapped. Finally the boxes are
    converted back to ``format`` and restored to the input shape and dtype.

    Args:
        bounding_boxes (torch.Tensor): Boxes expressed in ``format``.
        format (BoundingBoxFormat): Format of ``bounding_boxes``; also the output format.
        canvas_size (tuple[int, int]): Canvas size; only its entries' order and the
            second entry are used directly here.

    Returns:
        torch.Tensor: The processed boxes with the shape and dtype of the input.
    """
    input_shape = bounding_boxes.shape
    input_dtype = bounding_boxes.dtype

    # Always operate on a private floating point copy, so the in-place format
    # conversions below never touch the caller's tensor.
    if bounding_boxes.is_floating_point():
        working = bounding_boxes.clone()
    else:
        working = bounding_boxes.float()

    corners = convert_bounding_box_format(
        working, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, inplace=True
    )
    corners = corners.reshape(-1, 8)

    for _ in range(4):
        order, corners = _order_bounding_boxes_points(corners)
        corners = _clamp_along_y_axis(corners)
        # NOTE(review): this gathers with the same indices a second time, which is not
        # in general the inverse of the first re-ordering - confirm this is intended.
        _, corners = _order_bounding_boxes_points(corners, order)

        # rotate 90 degrees counter-clockwise
        xs = corners[:, ::2].clone()
        ys = corners[:, 1::2].clone()
        corners[:, ::2] = ys
        corners[:, 1::2] = canvas_size[1] - xs
        canvas_size = (canvas_size[1], canvas_size[0])

    restored = convert_bounding_box_format(
        corners, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format, inplace=True
    )
    return restored.reshape(input_shape).to(input_dtype)
363471

364472

365473
def clamp_bounding_boxes(

0 commit comments

Comments
 (0)