Skip to content

Commit 3966ff2

Browse files
Update elastic_bounding_boxes for rotated boxes
Test Plan: Unit tests: ```bash pytest test/test_transforms_v2.py -vvv -k "TestElastic and test_kernel_bounding_boxes" ```
1 parent 1105aa1 commit 3966ff2

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

test/test_transforms_v2.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3006,11 +3006,21 @@ def test_kernel_image(self, param, value, dtype, device):
30063006
check_cuda_vs_cpu=dtype is not torch.float16,
30073007
)
30083008

3009-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
3009+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
30103010
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
30113011
@pytest.mark.parametrize("device", cpu_and_cuda())
30123012
def test_kernel_bounding_boxes(self, format, dtype, device):
30133013
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
3014+
if tv_tensors.is_rotated_bounding_format(format):
3015+
# generated rotated test boxes can fall outside the canvas
3016+
# but the elastic transformation expects the boxes to be clamped.
3017+
# Also by convention the integer boxes should be allowed to
3018+
# reach width and height. But the grid for the elastic transform
3019+
# only covers up to width - 1 and height - 1. So we are tricking the
3020+
# test by making sure the boxes are clamped to at most width - 1 and height - 1.
3021+
bounding_boxes.canvas_size = (bounding_boxes.canvas_size[0] - 1, bounding_boxes.canvas_size[1] - 1)
3022+
bounding_boxes = F.clamp_bounding_boxes(bounding_boxes)
3023+
bounding_boxes.canvas_size = (bounding_boxes.canvas_size[0] + 1, bounding_boxes.canvas_size[1] + 1)
30143024

30153025
check_kernel(
30163026
F.elastic_bounding_boxes,

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,23 +2017,27 @@ def elastic_bounding_boxes(
20172017
# TODO: add in docstring about approximation we are doing for grid inversion
20182018
device = bounding_boxes.device
20192019
dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
2020+
is_rotated = tv_tensors.is_rotated_bounding_format(format)
20202021

20212022
if displacement.dtype != dtype or displacement.device != device:
20222023
displacement = displacement.to(dtype=dtype, device=device)
20232024

20242025
original_shape = bounding_boxes.shape
20252026
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format
2027+
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
2028+
20262029
bounding_boxes = (
2027-
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
2028-
).reshape(-1, 4)
2030+
convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format)
2031+
).reshape(-1, 8 if is_rotated else 4)
20292032

20302033
id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
20312034
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
20322035
# This is not an exact inverse of the grid
20332036
inv_grid = id_grid.sub_(displacement)
20342037

20352038
# Get points from bboxes
2036-
points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
2039+
points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
2040+
points = points.reshape(-1, 2)
20372041
if points.is_floating_point():
20382042
points = points.ceil_()
20392043
index_xy = points.to(dtype=torch.long)
@@ -2043,16 +2047,22 @@ def elastic_bounding_boxes(
20432047
t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
20442048
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
20452049

2046-
transformed_points = transformed_points.reshape(-1, 4, 2)
2047-
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
2050+
if is_rotated:
2051+
transformed_points = transformed_points.reshape(-1, 8)
2052+
out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype)
2053+
else:
2054+
transformed_points = transformed_points.reshape(-1, 4, 2)
2055+
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
2056+
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype)
2057+
20482058
out_bboxes = clamp_bounding_boxes(
2049-
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
2050-
format=tv_tensors.BoundingBoxFormat.XYXY,
2059+
out_bboxes,
2060+
format=intermediate_format,
20512061
canvas_size=canvas_size,
20522062
)
20532063

20542064
return convert_bounding_box_format(
2055-
out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
2065+
out_bboxes, old_format=intermediate_format, new_format=format, inplace=False
20562066
).reshape(original_shape)
20572067

20582068

0 commit comments

Comments
 (0)