Skip to content

Commit 9827ab6

Browse files
Add utility function to identify rotated box formats
Test Plan: Run unit tests:`pytest test/test_tv_tensors.py -vvv -k "test_bbox_format"`
1 parent 734aed2 commit 9827ab6

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

test/test_tv_tensors.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,72 @@ def test_bbox_instance(data, format):
4343
assert bboxes.format == format
4444

4545

46+
@pytest.mark.parametrize(
47+
"format",
48+
[
49+
"XYXY",
50+
"XYWH",
51+
"CXCYWH",
52+
"XYXYXYXY",
53+
"XYWHR",
54+
"CXCYWHR",
55+
tv_tensors.BoundingBoxFormat.XYXY,
56+
tv_tensors.BoundingBoxFormat.XYWH,
57+
tv_tensors.BoundingBoxFormat.CXCYWH,
58+
tv_tensors.BoundingBoxFormat.XYXYXYXY,
59+
tv_tensors.BoundingBoxFormat.XYWHR,
60+
tv_tensors.BoundingBoxFormat.CXCYWHR,
61+
],
62+
)
63+
def test_bbox_format(format):
64+
if isinstance(format, str):
65+
format = tv_tensors.BoundingBoxFormat[(format.upper())]
66+
if format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
67+
assert tv_tensors.is_rotated_bounding_format(format) is True
68+
elif format == tv_tensors.BoundingBoxFormat.XYWHR:
69+
assert tv_tensors.is_rotated_bounding_format(format) is True
70+
elif format == tv_tensors.BoundingBoxFormat.CXCYWHR:
71+
assert tv_tensors.is_rotated_bounding_format(format) is True
72+
else:
73+
assert tv_tensors.is_rotated_bounding_format(format) is False
74+
75+
76+
@pytest.mark.parametrize(
77+
"format",
78+
[
79+
"XYXY",
80+
"XYWH",
81+
"CXCYWH",
82+
"XYXYXYXY",
83+
"XYWHR",
84+
"CXCYWHR",
85+
tv_tensors.BoundingBoxFormat.XYXY,
86+
tv_tensors.BoundingBoxFormat.XYWH,
87+
tv_tensors.BoundingBoxFormat.CXCYWH,
88+
tv_tensors.BoundingBoxFormat.XYXYXYXY,
89+
tv_tensors.BoundingBoxFormat.XYWHR,
90+
tv_tensors.BoundingBoxFormat.CXCYWHR,
91+
],
92+
)
93+
def test_bbox_format_scripted(format):
94+
obj = tv_tensors.is_rotated_bounding_format
95+
try:
96+
fn = torch.jit.script(obj)
97+
except Exception as error:
98+
name = getattr(obj, "__name__", obj.__class__.__name__)
99+
raise AssertionError(f"Trying to `torch.jit.script` `{name}` raised the error above.") from error
100+
if isinstance(format, str):
101+
format = tv_tensors.BoundingBoxFormat[(format.upper())]
102+
if format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
103+
assert fn(format) is True
104+
elif format == tv_tensors.BoundingBoxFormat.XYWHR:
105+
assert fn(format) is True
106+
elif format == tv_tensors.BoundingBoxFormat.CXCYWHR:
107+
assert fn(format) is True
108+
else:
109+
assert fn(format) is False
110+
111+
46112
def test_bbox_dim_error():
47113
data_3d = [[[1, 2, 3, 4]]]
48114
with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):

torchvision/tv_tensors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22

3-
from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat
3+
from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format
44
from ._image import Image
55
from ._mask import Mask
66
from ._torch_function_helpers import set_return_type

torchvision/tv_tensors/_bounding_boxes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ class BoundingBoxFormat(Enum):
3838
XYXYXYXY = "XYXYXYXY"
3939

4040

41+
# TODO: Once torchscript supports Enums with staticmethod
42+
# this can be put into BoundingBoxFormat as staticmethod
43+
def is_rotated_bounding_format(format: BoundingBoxFormat) -> bool:
44+
return (
45+
format == BoundingBoxFormat.XYWHR or format == BoundingBoxFormat.CXCYWHR or format == BoundingBoxFormat.XYXYXYXY
46+
)
47+
48+
4149
class BoundingBoxes(TVTensor):
4250
""":class:`torch.Tensor` subclass for bounding boxes with shape ``[N, K]``.
4351

0 commit comments

Comments
 (0)