Skip to content

Commit 2a361ef

Browse files
Adjust rotated clamping conditions
Test Plan: ```bash pytest test/test_transforms_v2.py -k box -v ```
1 parent d5df0d6 commit 2a361ef

File tree

4 files changed

+47
-42
lines changed

4 files changed

+47
-42
lines changed

test/common_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,9 @@ def sample_position(values, max_value):
469469
raise ValueError(f"Format {format} is not supported")
470470
out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
471471
if tv_tensors.is_rotated_bounding_format(format):
472-
# The rotated bounding boxes are not guaranteed to be within the canvas by design,
473-
# so we apply clamping. We also add a 2 buffer to the canvas size to avoid
474-
# numerical issues during the testing
472+
# Rotated bounding boxes are not inherently confined within the canvas, so clamping is applied.
473+
# Transform tests allow a 2-pixel tolerance relative to the canvas size.
474+
# To prevent discrepancies when clamping with different canvas sizes, we add a 2-pixel buffer.
475475
buffer = 4
476476
out_boxes = clamp_bounding_boxes(
477477
out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)

test/test_transforms_v2.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4421,9 +4421,15 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
44214421
else reference_affine_bounding_boxes_helper
44224422
)
44234423

4424+
bounding_boxes = helper(
4425+
bounding_boxes,
4426+
affine_matrix=crop_affine_matrix,
4427+
new_canvas_size=(height, width)
4428+
)
4429+
44244430
return helper(
44254431
bounding_boxes,
4426-
affine_matrix=affine_matrix,
4432+
affine_matrix=resize_affine_matrix,
44274433
new_canvas_size=size,
44284434
)
44294435

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,8 +1104,9 @@ def _affine_bounding_boxes_with_expand(
11041104

11051105
original_shape = bounding_boxes.shape
11061106
dtype = bounding_boxes.dtype
1107-
need_cast = not bounding_boxes.is_floating_point()
1108-
bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone()
1107+
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
1108+
need_cast = dtype not in acceptable_dtypes
1109+
bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
11091110
device = bounding_boxes.device
11101111
is_rotated = tv_tensors.is_rotated_bounding_format(format)
11111112
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
@@ -2397,19 +2398,19 @@ def elastic_bounding_boxes(
23972398

23982399
original_shape = bounding_boxes.shape
23992400
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format
2400-
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
2401+
intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
24012402

24022403
bounding_boxes = (
24032404
convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format)
2404-
).reshape(-1, 8 if is_rotated else 4)
2405+
).reshape(-1, 5 if is_rotated else 4)
24052406

24062407
id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
24072408
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
24082409
# This is not an exact inverse of the grid
24092410
inv_grid = id_grid.sub_(displacement)
24102411

24112412
# Get points from bboxes
2412-
points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
2413+
points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
24132414
points = points.reshape(-1, 2)
24142415
if points.is_floating_point():
24152416
points = points.ceil_()
@@ -2421,8 +2422,8 @@ def elastic_bounding_boxes(
24212422
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
24222423

24232424
if is_rotated:
2424-
transformed_points = transformed_points.reshape(-1, 8)
2425-
out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype)
2425+
transformed_points = transformed_points.reshape(-1, 2)
2426+
out_bboxes = torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1).to(bounding_boxes.dtype)
24262427
else:
24272428
transformed_points = transformed_points.reshape(-1, 4, 2)
24282429
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

torchvision/transforms/v2/functional/_meta.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -409,23 +409,17 @@ def _order_bounding_boxes_points(
409409
if indices is None:
410410
output_xyxyxyxy = bounding_boxes.reshape(-1, 8)
411411
x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2]
412-
y_max = torch.max(y, dim=1, keepdim=True)[0]
413-
_, x1 = ((y_max - y) / y_max + (x + 1) * 100).min(dim=1)
412+
y_max = torch.max(y.abs(), dim=1, keepdim=True)[0]
413+
_, x1 = (y / y_max + (x + 1) * 100).min(dim=1)
414414
indices = torch.ones_like(output_xyxyxyxy)
415415
indices[..., 0] = x1.mul(2)
416416
indices.cumsum_(1).remainder_(8)
417417
return indices, bounding_boxes.gather(1, indices.to(torch.int64))
418418

419419

420-
def _area(box: torch.Tensor) -> torch.Tensor:
421-
x1, y1, x2, y2, x3, y3, x4, y4 = box.reshape(-1, 8).unbind(-1)
422-
w = torch.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2)
423-
h = torch.sqrt((y3 - y2) ** 2 + (x3 - x2) ** 2)
424-
return w * h
425-
426-
427420
def _clamp_along_y_axis(
428421
bounding_boxes: torch.Tensor,
422+
canvas_size: tuple[int, int],
429423
) -> torch.Tensor:
430424
"""
431425
Adjusts bounding boxes along the y-axis based on specific conditions.
@@ -448,29 +442,33 @@ def _clamp_along_y_axis(
448442
b2 = y2 + x2 / a
449443
b3 = y3 - a * x3
450444
b4 = y4 + x4 / a
451-
b23 = (b2 - b3) / 2 * a / (1 + a**2)
452-
z = torch.zeros_like(b1)
453-
case_a = torch.cat([x.unsqueeze(1) for x in [z, b1, x2, y2, x3, y3, x3 - x2, y3 + b1 - y2]], dim=1)
454-
case_b = torch.cat([x.unsqueeze(1) for x in [z, b4, x2 - x1, y2 - y1 + b4, x3, y3, x4, y4]], dim=1)
455-
case_c = torch.cat(
456-
[x.unsqueeze(1) for x in [z, (b2 + b3) / 2, b23, -b23 / a + b2, x3, y3, b23, b23 * a + b3]], dim=1
445+
c = a / (1 + a**2)
446+
b1 = b2.clamp(0).clamp(b1, b3)
447+
b4 = b3.clamp(max=canvas_size[0]).clamp(b2, b4)
448+
case_a = torch.stack(
449+
(
450+
(b4 - b1) * c,
451+
(b4 - b1) * c * a + b1,
452+
(b2 - b1) * c,
453+
(b1 - b2) * c / a + b2,
454+
x3,
455+
y3,
456+
(b4 - b3) * c,
457+
(b3 - b4) * c / a + b4,
458+
),
459+
dim=-1,
457460
)
458-
case_d = torch.zeros_like(case_c)
459-
case_e = torch.cat([x.unsqueeze(1) for x in [x1.clamp(0), y1, x2.clamp(0), y2, x3, y3, x4, y4]], dim=1)
460-
461-
cond_a = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
462-
cond_a = cond_a.logical_and(_area(case_a) > _area(case_b))
463-
cond_a = cond_a.logical_or((x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 <= 0))
464-
cond_b = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
465-
cond_b = cond_b.logical_and(_area(case_a) <= _area(case_b))
466-
cond_b = cond_b.logical_or((x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 >= 0))
467-
cond_c = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 <= 0)
468-
cond_d = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0)
469-
cond_e = x1.isclose(x2)
470-
461+
case_b = bounding_boxes.clone()
462+
case_b[..., 0].clamp_(0)
463+
case_b[..., 6].clamp_(0)
464+
case_c = torch.zeros_like(case_b)
465+
466+
cond_a = x1 < 0
467+
cond_b = y1.isclose(y2, rtol=1e-05, atol=1e-05)
468+
cond_c = (x1 <= 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0)
471469
for cond, case in zip(
472-
[cond_a, cond_b, cond_c, cond_d, cond_e],
473-
[case_a, case_b, case_c, case_d, case_e],
470+
[cond_a, cond_b, cond_c],
471+
[case_a, case_b, case_c],
474472
):
475473
bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
476474
return bounding_boxes.to(original_dtype).reshape(original_shape)
@@ -512,7 +510,7 @@ def _clamp_rotated_bounding_boxes(
512510

513511
for _ in range(4): # Iterate over the 4 vertices.
514512
indices, out_boxes = _order_bounding_boxes_points(out_boxes)
515-
out_boxes = _clamp_along_y_axis(out_boxes)
513+
out_boxes = _clamp_along_y_axis(out_boxes, canvas_size)
516514
_, out_boxes = _order_bounding_boxes_points(out_boxes, indices)
517515
# rotate 90 degrees counter clock wise
518516
out_boxes[:, ::2], out_boxes[:, 1::2] = (

0 commit comments

Comments
 (0)