Skip to content

Commit 69f1eef

Browse files
added test
1 parent 49962b5 commit 69f1eef

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

test/test_tv_tensors.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,33 @@ def test_bbox_format(format, is_rotated_expected, scripted):
7777
fn = torch.jit.script(fn)
7878
assert fn(format) == is_rotated_expected
7979

80+
@pytest.mark.parametrize(
81+
"format, support_integer_dtype",
82+
[
83+
("XYXY", True),
84+
("XYWH", True),
85+
("CXCYWH", True),
86+
("XYXYXYXY", False),
87+
("XYWHR", False),
88+
("CXCYWHR", False),
89+
(tv_tensors.BoundingBoxFormat.XYXY, True),
90+
(tv_tensors.BoundingBoxFormat.XYWH, True),
91+
(tv_tensors.BoundingBoxFormat.CXCYWH, True),
92+
(tv_tensors.BoundingBoxFormat.XYXYXYXY, False),
93+
(tv_tensors.BoundingBoxFormat.XYWHR, False),
94+
(tv_tensors.BoundingBoxFormat.CXCYWHR, False),
95+
],
96+
)
97+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
98+
def test_bbox_format_dtype(format, support_integer_dtype, input_dtype):
99+
print(format, support_integer_dtype, input_dtype)
100+
if not input_dtype.is_floating_point and not support_integer_dtype:
101+
pytest.xfail("Rotated bounding boxes should be floating point tensors")
102+
bboxes = tv_tensors.BoundingBoxes(torch.randint(0, 32, size=(5, 2)), format=format, canvas_size=(32, 32))
103+
else:
104+
bboxes = tv_tensors.BoundingBoxes(torch.rand(size=(5, 2)), format=format, canvas_size=(32, 32))
105+
assert isinstance(bboxes, torch.Tensor)
106+
80107

81108
def test_bbox_dim_error():
82109
data_3d = [[[1, 2, 3, 4]]]

0 commit comments

Comments
 (0)