Skip to content

Commit a7d07dc

Browse files
Add vertical_flip_bounding_boxes
Test Plan: Run unit tests: `pytest test/test_transforms_v2.py -vvv -k "TestVerticalFlip and test_kernel_bounding_boxes"`
1 parent: 95ed7cf · commit: a7d07dc

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

test/test_transforms_v2.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1531,7 +1531,7 @@ class TestVerticalFlip:
15311531
def test_kernel_image(self, dtype, device):
15321532
check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))
15331533

1534-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1534+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
15351535
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
15361536
@pytest.mark.parametrize("device", cpu_and_cuda())
15371537
def test_kernel_bounding_boxes(self, format, dtype, device):
@@ -1588,23 +1588,25 @@ def test_image_correctness(self, fn):
15881588

15891589
torch.testing.assert_close(actual, expected)
15901590

1591-
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes):
1591+
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes, format):
15921592
affine_matrix = np.array(
15931593
[
15941594
[1, 0, 0],
15951595
[0, -1, bounding_boxes.canvas_size[0]],
15961596
],
15971597
)
15981598

1599+
if tv_tensors.is_rotated_bounding_format(format):
1600+
return reference_affine_rotated_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
15991601
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
16001602

1601-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1603+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
16021604
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
16031605
def test_bounding_boxes_correctness(self, format, fn):
16041606
bounding_boxes = make_bounding_boxes(format=format)
16051607

16061608
actual = fn(bounding_boxes)
1607-
expected = self._reference_vertical_flip_bounding_boxes(bounding_boxes)
1609+
expected = self._reference_vertical_flip_bounding_boxes(bounding_boxes, format)
16081610

16091611
torch.testing.assert_close(actual, expected)
16101612

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -146,14 +146,32 @@ def vertical_flip_bounding_boxes(
146146
) -> torch.Tensor:
147147
shape = bounding_boxes.shape
148148

149-
bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
149+
if tv_tensors.is_rotated_bounding_format(format):
150+
bounding_boxes = (
151+
bounding_boxes.clone().reshape(-1, 5)
152+
if format != tv_tensors.BoundingBoxFormat.XYXYXYXY
153+
else bounding_boxes.clone().reshape(-1, 8)
154+
)
155+
else:
156+
bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
150157

151158
if format == tv_tensors.BoundingBoxFormat.XYXY:
152159
bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
153160
elif format == tv_tensors.BoundingBoxFormat.XYWH:
154161
bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
155-
else: # format == tv_tensors.BoundingBoxFormat.CXCYWH:
162+
elif format == tv_tensors.BoundingBoxFormat.CXCYWH:
156163
bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()
164+
elif format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
165+
bounding_boxes[:, 1::2].sub_(canvas_size[0]).neg_()
166+
bounding_boxes = bounding_boxes[:, [0, 1, 6, 7, 4, 5, 2, 3]]
167+
elif format == tv_tensors.BoundingBoxFormat.XYWHR:
168+
bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()
169+
bounding_boxes = bounding_boxes[:, [0, 1, 3, 2, 4]]
170+
bounding_boxes[:, -1].sub_(90).neg_()
171+
else: # format == tv_tensors.BoundingBoxFormat.CXCYWHR:
172+
bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()
173+
bounding_boxes = bounding_boxes[:, [0, 1, 3, 2, 4]]
174+
bounding_boxes[:, -1].sub_(90).neg_()
157175

158176
return bounding_boxes.reshape(shape)
159177

0 commit comments

Comments
 (0)