Skip to content

Commit 95ed7cf

Browse files
Add horizontal_flip_rotated_bounding_boxes
Test Plan: Run unit tests: `pytest test/test_transforms_v2.py -vvv -k "TestHorizontalFlip and test_kernel_bounding_boxes"` and `pytest test/test_transforms_v2.py -vvv -k "TestHorizontalFlip and test_bounding_boxes_correctness"`
1 parent 87a238c commit 95ed7cf

File tree

2 files changed

+91
-6
lines changed

2 files changed

+91
-6
lines changed

test/test_transforms_v2.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,71 @@ def affine_bounding_boxes(bounding_boxes):
560560
)
561561

562562

563+
def reference_affine_rotated_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
564+
format = bounding_boxes.format
565+
canvas_size = new_canvas_size or bounding_boxes.canvas_size
566+
567+
def affine_rotated_bounding_boxes(bounding_boxes):
568+
dtype = bounding_boxes.dtype
569+
device = bounding_boxes.device
570+
571+
# Go to float before converting to prevent precision loss in case of CXCYWHR -> XYXYXYXY and W or H is 1
572+
input_xyxyxyxy = F.convert_bounding_box_format(
573+
bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True),
574+
old_format=format,
575+
new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY,
576+
inplace=True,
577+
)
578+
x1, y1, x3, y3, x2, y2, x4, y4 = input_xyxyxyxy.squeeze(0).tolist()
579+
580+
points = np.array(
581+
[
582+
[x1, y1, 1.0],
583+
[x3, y3, 1.0],
584+
[x2, y2, 1.0],
585+
[x4, y4, 1.0],
586+
]
587+
)
588+
transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T)
589+
output = torch.Tensor(
590+
[
591+
float(transformed_points[0, 0]),
592+
float(transformed_points[0, 1]),
593+
float(transformed_points[3, 0]),
594+
float(transformed_points[3, 1]),
595+
float(transformed_points[2, 0]),
596+
float(transformed_points[2, 1]),
597+
float(transformed_points[1, 0]),
598+
float(transformed_points[1, 1]),
599+
]
600+
)
601+
602+
output = F.convert_bounding_box_format(
603+
output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
604+
)
605+
606+
if clamp:
607+
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
608+
output = F.clamp_bounding_boxes(
609+
output,
610+
format=format,
611+
canvas_size=canvas_size,
612+
)
613+
else:
614+
# We leave the bounding box as float64 so the caller gets the full precision to perform any additional
615+
# operation
616+
dtype = output.dtype
617+
618+
return output.to(dtype=dtype, device=device)
619+
620+
return tv_tensors.BoundingBoxes(
621+
torch.cat([affine_rotated_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 5 if format != tv_tensors.BoundingBoxFormat.XYXYXYXY else 8).unbind()], dim=0).reshape(
622+
bounding_boxes.shape
623+
),
624+
format=format,
625+
canvas_size=canvas_size,
626+
)
627+
563628
class TestResize:
564629
INPUT_SIZE = (17, 11)
565630
OUTPUT_SIZES = [17, [17], (17,), None, [12, 13], (12, 13)]
@@ -1012,7 +1077,7 @@ class TestHorizontalFlip:
10121077
def test_kernel_image(self, dtype, device):
10131078
check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))
10141079

1015-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1080+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
10161081
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
10171082
@pytest.mark.parametrize("device", cpu_and_cuda())
10181083
def test_kernel_bounding_boxes(self, format, dtype, device):
@@ -1071,25 +1136,27 @@ def test_image_correctness(self, fn):
10711136

10721137
torch.testing.assert_close(actual, expected)
10731138

1074-
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):
1139+
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes, format):
10751140
affine_matrix = np.array(
10761141
[
10771142
[-1, 0, bounding_boxes.canvas_size[1]],
10781143
[0, 1, 0],
10791144
],
10801145
)
10811146

1147+
if tv_tensors.is_rotated_bounding_format(format):
1148+
return reference_affine_rotated_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
10821149
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
10831150

1084-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1151+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
10851152
@pytest.mark.parametrize(
10861153
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
10871154
)
10881155
def test_bounding_boxes_correctness(self, format, fn):
10891156
bounding_boxes = make_bounding_boxes(format=format)
10901157

10911158
actual = fn(bounding_boxes)
1092-
expected = self._reference_horizontal_flip_bounding_boxes(bounding_boxes)
1159+
expected = self._reference_horizontal_flip_bounding_boxes(bounding_boxes, format)
10931160

10941161
torch.testing.assert_close(actual, expected)
10951162

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,32 @@ def horizontal_flip_bounding_boxes(
7171
) -> torch.Tensor:
7272
shape = bounding_boxes.shape
7373

74-
bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
74+
if tv_tensors.is_rotated_bounding_format(format):
75+
bounding_boxes = (
76+
bounding_boxes.clone().reshape(-1, 5)
77+
if format != tv_tensors.BoundingBoxFormat.XYXYXYXY
78+
else bounding_boxes.clone().reshape(-1, 8)
79+
)
80+
else:
81+
bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
7582

7683
if format == tv_tensors.BoundingBoxFormat.XYXY:
7784
bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
7885
elif format == tv_tensors.BoundingBoxFormat.XYWH:
7986
bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
80-
else: # format == tv_tensors.BoundingBoxFormat.CXCYWH:
87+
elif format == tv_tensors.BoundingBoxFormat.CXCYWH:
88+
bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()
89+
elif format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
90+
bounding_boxes[:, 0::2].sub_(canvas_size[1]).neg_()
91+
bounding_boxes = bounding_boxes[:, [0, 1, 6, 7, 4, 5, 2, 3]]
92+
elif format == tv_tensors.BoundingBoxFormat.XYWHR:
93+
bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()
94+
bounding_boxes = bounding_boxes[:, [0, 1, 3, 2, 4]]
95+
bounding_boxes[:, -1].add_(90).neg_()
96+
else: # format == tv_tensors.BoundingBoxFormat.CXCYWHR:
8197
bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()
98+
bounding_boxes = bounding_boxes[:, [0, 1, 3, 2, 4]]
99+
bounding_boxes[:, -1].add_(90).neg_()
82100

83101
return bounding_boxes.reshape(shape)
84102

0 commit comments

Comments
 (0)