Skip to content

Commit a28fa39

Browse files
Update _affine_bounding_boxes_with_expand for rotated boxes
Test Plan: Run the unit tests:
```bash
pytest test/test_transforms_v2.py -vvv -k "TestAffine and test_kernel_bounding_boxes"
pytest test/test_transforms_v2.py -vvv -k "TestAffine and test_functional_bounding_boxes_correctness"
```
1 parent 782f406 commit a28fa39

File tree

2 files changed

+35
-16
lines changed

2 files changed

+35
-16
lines changed

test/test_transforms_v2.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,7 +1269,7 @@ def test_kernel_image(self, param, value, dtype, device):
12691269
shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
12701270
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
12711271
)
1272-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1272+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
12731273
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
12741274
@pytest.mark.parametrize("device", cpu_and_cuda())
12751275
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
@@ -1411,14 +1411,22 @@ def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate,
14111411
if center is None:
14121412
center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
14131413

1414-
return reference_affine_bounding_boxes_helper(
1414+
affine_matrix = self._compute_affine_matrix(
1415+
angle=angle, translate=translate, scale=scale, shear=shear, center=center
1416+
)
1417+
1418+
helper = (
1419+
reference_affine_rotated_bounding_boxes_helper
1420+
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
1421+
else reference_affine_bounding_boxes_helper
1422+
)
1423+
1424+
return helper(
14151425
bounding_boxes,
1416-
affine_matrix=self._compute_affine_matrix(
1417-
angle=angle, translate=translate, scale=scale, shear=shear, center=center
1418-
),
1426+
affine_matrix=affine_matrix,
14191427
)
14201428

1421-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1429+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
14221430
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
14231431
@pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
14241432
@pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -893,11 +893,15 @@ def _affine_bounding_boxes_with_expand(
893893
bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
894894
dtype = bounding_boxes.dtype
895895
device = bounding_boxes.device
896+
intermediate_format = (
897+
tv_tensors.BoundingBoxFormat.XYXYXYXY
898+
if tv_tensors.is_rotated_bounding_format(format)
899+
else tv_tensors.BoundingBoxFormat.XYXY
900+
)
901+
intermediate_shape = 8 if tv_tensors.is_rotated_bounding_format(format) else 4
896902
bounding_boxes = (
897-
convert_bounding_box_format(
898-
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
899-
)
900-
).reshape(-1, 4)
903+
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format, inplace=True)
904+
).reshape(-1, intermediate_shape)
901905

902906
angle, translate, shear, center = _affine_parse_args(
903907
angle, translate, scale, shear, InterpolationMode.NEAREST, center
@@ -921,15 +925,22 @@ def _affine_bounding_boxes_with_expand(
921925
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
922926
# Single point structure is similar to
923927
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
924-
points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
928+
if tv_tensors.is_rotated_bounding_format(format):
929+
points = bounding_boxes.reshape(-1, 2)
930+
else:
931+
points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
925932
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
926933
# 2) Now let's transform the points using affine matrix
927934
transformed_points = torch.matmul(points, transposed_affine_matrix)
928935
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
929936
# and compute bounding box from 4 transformed points:
930-
transformed_points = transformed_points.reshape(-1, 4, 2)
931-
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
932-
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
937+
if tv_tensors.is_rotated_bounding_format(format):
938+
transformed_points = transformed_points.reshape(-1, 8)
939+
out_bboxes = _parallelogram_to_bounding_boxes(transformed_points)
940+
else:
941+
transformed_points = transformed_points.reshape(-1, 4, 2)
942+
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
943+
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
933944

934945
if expand:
935946
# Compute minimum point for transformed image frame:
@@ -954,9 +965,9 @@ def _affine_bounding_boxes_with_expand(
954965
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
955966
canvas_size = (new_height, new_width)
956967

957-
out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
968+
out_bboxes = clamp_bounding_boxes(out_bboxes, format=intermediate_format, canvas_size=canvas_size)
958969
out_bboxes = convert_bounding_box_format(
959-
out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
970+
out_bboxes, old_format=intermediate_format, new_format=format, inplace=True
960971
).reshape(original_shape)
961972

962973
out_bboxes = out_bboxes.to(original_dtype)

0 commit comments

Comments
 (0)