Skip to content

Commit 207fe96

Browse files
Add rotate transformation tests for rotated boxes
Test Plan: Unit tests:
```bash
pytest test/test_transforms_v2.py -vvv -k "TestRotate and test_kernel_bounding_boxes"
pytest test/test_transforms_v2.py -vvv -k "TestRotate and test_functional_bounding_boxes_correctness"
pytest test/test_transforms_v2.py -vvv -k "TestRotate and test_transform_bounding_boxes_correctness"
```
1 parent 8dc9ce4 commit 207fe96

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

test/test_transforms_v2.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1889,7 +1889,7 @@ def test_kernel_image(self, param, value, dtype, device):
18891889
expand=[False, True],
18901890
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
18911891
)
1892-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1892+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
18931893
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
18941894
@pytest.mark.parametrize("device", cpu_and_cuda())
18951895
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
@@ -2025,6 +2025,13 @@ def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
20252025
x, y = recenter_xy
20262026
if bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXY:
20272027
translate = [x, y, x, y]
2028+
elif bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
2029+
translate = [x, y, x, y, x, y, x, y]
2030+
elif (
2031+
bounding_boxes.format is tv_tensors.BoundingBoxFormat.CXCYWHR
2032+
or bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYWHR
2033+
):
2034+
translate = [x, y, 0.0, 0.0, 0.0]
20282035
else:
20292036
translate = [x, y, 0.0, 0.0]
20302037
return tv_tensors.wrap(
@@ -2049,7 +2056,12 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
20492056
expand=expand, canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix
20502057
)
20512058

2052-
output = reference_affine_bounding_boxes_helper(
2059+
helper = (
2060+
reference_affine_rotated_bounding_boxes_helper
2061+
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
2062+
else reference_affine_bounding_boxes_helper
2063+
)
2064+
output = helper(
20532065
bounding_boxes,
20542066
affine_matrix=affine_matrix,
20552067
new_canvas_size=new_canvas_size,
@@ -2060,7 +2072,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
20602072
bounding_boxes
20612073
)
20622074

2063-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
2075+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
20642076
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
20652077
@pytest.mark.parametrize("expand", [False, True])
20662078
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@@ -2073,7 +2085,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent
20732085
torch.testing.assert_close(actual, expected)
20742086
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
20752087

2076-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
2088+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
20772089
@pytest.mark.parametrize("expand", [False, True])
20782090
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
20792091
@pytest.mark.parametrize("seed", list(range(5)))

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -983,7 +983,7 @@ def _affine_bounding_boxes_with_expand(
983983
new_points = torch.matmul(points, transposed_affine_matrix)
984984
tr = torch.amin(new_points, dim=0, keepdim=True)
985985
# Translate bounding boxes
986-
out_bboxes.sub_(tr.repeat((1, 2)))
986+
out_bboxes.sub_(tr.repeat((1, 4 if is_rotated else 2)))
987987
# Estimate meta-data for image with inverted=True
988988
affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
989989
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)

0 commit comments

Comments
 (0)