Skip to content

Commit e99b82a

Browse files
committed
Improved convert_bounding_boxes_to_points to handle rotated bounding boxes and added tests for all formats
1 parent 4b62ef4 commit e99b82a

File tree

2 files changed

+55
-12
lines changed

2 files changed

+55
-12
lines changed

test/test_transforms_v2.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6877,18 +6877,39 @@ def test_no_valid_input(self, query):
68776877
query(["blah"])
68786878

68796879
@pytest.mark.parametrize(
6880-
"boxes", [tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))]
6880+
"boxes", [
6881+
tv_tensors.BoundingBoxes(torch.tensor([[1., 1., 2., 2.]]), format="XYXY", canvas_size=(4, 4)), # [boxes0]
6882+
tv_tensors.BoundingBoxes(torch.tensor([[1., 1., 1., 1.]]), format="XYWH", canvas_size=(4, 4)), # [boxes1]
6883+
tv_tensors.BoundingBoxes(torch.tensor([[1.5, 1.5, 1., 1.]]), format="CXCYWH", canvas_size=(4, 4)), # [boxes2]
6884+
tv_tensors.BoundingBoxes(torch.tensor([[1.5, 1.5, 1., 1., 45]]), format="CXCYWHR", canvas_size=(4, 4)), # [boxes3]
6885+
tv_tensors.BoundingBoxes(torch.tensor([[1., 1., 1., 1., 45.]]), format="XYWHR", canvas_size=(4, 4)), # [boxes4]
6886+
tv_tensors.BoundingBoxes(torch.tensor([[1., 1., 1., 2., 2., 2., 2., 1.]]), format="XY" * 4, canvas_size=(4, 4)), # [boxes5]
6887+
]
68816888
)
68826889
def test_convert_bounding_boxes_to_points(self, boxes: tv_tensors.BoundingBoxes):
6883-
# TODO: this test can't handle rotated boxes yet
68846890
kp = F.convert_bounding_boxes_to_points(boxes)
6885-
assert kp.shape == boxes.shape + (2,)
6891+
assert kp.shape == (boxes.shape[0], 4, 2)
68866892
assert kp.dtype == boxes.dtype
68876893
# kp is a list of A, B, C, D polygons.
6888-
# If we use A | C, we should get back the XYXY format of bounding box
6889-
reconverted = torch.cat([kp[..., 0, :], kp[..., 2, :]], dim=-1)
6890-
reconverted_bbox = F.convert_bounding_box_format(
6891-
tv_tensors.BoundingBoxes(reconverted, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=kp.canvas_size),
6892-
new_format=boxes.format,
6893-
)
6894-
assert (reconverted_bbox == boxes).all(), f"Invalid reconversion : {reconverted_bbox}"
6894+
6895+
if F._meta.is_rotated_bounding_box_format(boxes.format):
6896+
# In the rotated case
6897+
# If we convert to XYXYXYXY format, we should get what we want.
6898+
reconverted = kp.reshape(-1, 8)
6899+
reconverted_bbox = F.convert_bounding_box_format(
6900+
tv_tensors.BoundingBoxes(reconverted, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, canvas_size=kp.canvas_size),
6901+
new_format=boxes.format
6902+
)
6903+
assert ((reconverted_bbox - boxes).abs() < 1e-5).all(), ( # Rotational computations mean that we can't ensure exactitude.
6904+
f"Invalid reconversion :\n\tGot: {reconverted_bbox}\n\tFrom: {boxes}\n\t"
6905+
f"Diff: {reconverted_bbox - boxes}"
6906+
)
6907+
else:
6908+
# In the unrotated case
6909+
# If we use A | C, we should get back the XYXY format of bounding box
6910+
reconverted = torch.cat([kp[..., 0, :], kp[..., 2, :]], dim=-1)
6911+
reconverted_bbox = F.convert_bounding_box_format(
6912+
tv_tensors.BoundingBoxes(reconverted, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=kp.canvas_size),
6913+
new_format=boxes.format,
6914+
)
6915+
assert (reconverted_bbox == boxes).all(), f"Invalid reconversion :\n\tGot: {reconverted_bbox}\n\tFrom: {boxes}"

torchvision/transforms/v2/functional/_meta.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,38 @@ def _xyxy_to_keypoints(bounding_boxes: torch.Tensor) -> torch.Tensor:
185185
return bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
186186

187187

188+
def _xyxyxyxy_to_keypoints(bounding_boxes: torch.Tensor) -> torch.Tensor:
189+
return bounding_boxes[:, [[0, 1], [2, 3], [4, 5], [6, 7]]]
190+
191+
188192
def convert_bounding_boxes_to_points(bounding_boxes: tv_tensors.BoundingBoxes) -> tv_tensors.KeyPoints:
189193
"""Converts a set of bounding boxes to its edge points.
190194
195+
.. note::
196+
197+
This handles rotated :class:`tv_tensors.BoundingBoxes` formats
198+
by first converting them to XYXYXYXY format.
199+
200+
Due to floating-point approximation, this may not be an exact computation.
201+
191202
Args:
192203
bounding_boxes (tv_tensors.BoundingBoxes): A set of ``N`` bounding boxes (of shape ``[N, 4]``, or ``[N, 5]`` / ``[N, 8]`` for rotated formats)
193204
194205
Returns:
195-
tv_tensors.KeyPoints: The edges, of shape ``[N, 4, 2]``
206+
tv_tensors.KeyPoints: The corner points, as a polygon of shape ``[N, 4, 2]``
196207
"""
197-
# TODO: support rotated BBOX
208+
if is_rotated_bounding_box_format(bounding_boxes.format):
209+
# We are working on a rotated bounding box
210+
bbox = _convert_bounding_box_format(
211+
bounding_boxes.as_subclass(torch.Tensor),
212+
old_format=bounding_boxes.format,
213+
new_format=BoundingBoxFormat.XYXYXYXY,
214+
inplace=False,
215+
)
216+
return tv_tensors.KeyPoints(
217+
_xyxyxyxy_to_keypoints(bbox), canvas_size=bounding_boxes.canvas_size
218+
)
219+
198220
bbox = _convert_bounding_box_format(
199221
bounding_boxes.as_subclass(torch.Tensor),
200222
old_format=bounding_boxes.format,

0 commit comments

Comments
 (0)