Skip to content

Commit d94e031

Browse files
committed
Fix is_rotated_bounding_format to accept str and fix test_bbox_format_dtype
1 parent 2e6a52b commit d94e031

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

test/test_tv_tensors.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ def test_bbox_instance(data, format):
6969
)
7070
@pytest.mark.parametrize("scripted", (False, True))
7171
def test_bbox_format(format, is_rotated_expected, scripted):
72-
if isinstance(format, str):
73-
format = tv_tensors.BoundingBoxFormat[(format.upper())]
74-
7572
fn = tv_tensors.is_rotated_bounding_format
7673
if scripted:
7774
fn = torch.jit.script(fn)
@@ -97,13 +94,12 @@ def test_bbox_format(format, is_rotated_expected, scripted):
9794
)
9895
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
9996
def test_bbox_format_dtype(format, support_integer_dtype, input_dtype):
100-
print(format, support_integer_dtype, input_dtype)
97+
tensor = torch.randint(0, 32, size=(5, 2), dtype=input_dtype)
10198
if not input_dtype.is_floating_point and not support_integer_dtype:
102-
pytest.xfail("Rotated bounding boxes should be floating point tensors")
103-
bboxes = tv_tensors.BoundingBoxes(torch.randint(0, 32, size=(5, 2)), format=format, canvas_size=(32, 32))
99+
with pytest.raises(ValueError, match="Rotated bounding boxes should be floating point tensors"):
100+
tv_tensors.BoundingBoxes(tensor, format=format, canvas_size=(32, 32))
104101
else:
105-
bboxes = tv_tensors.BoundingBoxes(torch.rand(size=(5, 2)), format=format, canvas_size=(32, 32))
106-
assert isinstance(bboxes, torch.Tensor)
102+
tv_tensors.BoundingBoxes(tensor, format=format, canvas_size=(32, 32))
107103

108104

109105
def test_bbox_dim_error():
@@ -437,5 +433,5 @@ def test_return_type_input():
437433

438434

439435
def test_box_clamping_mode_default():
440-
assert tv_tensors.BoundingBoxes([0, 0, 10, 10], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft"
441-
assert tv_tensors.BoundingBoxes([0, 0, 10, 10, 0], format="XYWHR", canvas_size=(100, 100)).clamping_mode == "soft"
436+
assert tv_tensors.BoundingBoxes([0., 0., 10., 10.], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft"
437+
assert tv_tensors.BoundingBoxes([0., 0., 10., 10., 0.], format="XYWHR", canvas_size=(100, 100)).clamping_mode == "soft"

torchvision/tv_tensors/_bounding_boxes.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,13 @@ class BoundingBoxFormat(Enum):
4040

4141
# TODO: Once torchscript supports Enums with staticmethod
4242
# this can be put into BoundingBoxFormat as staticmethod
43-
def is_rotated_bounding_format(format: BoundingBoxFormat) -> bool:
44-
return (
45-
format == BoundingBoxFormat.XYWHR or format == BoundingBoxFormat.CXCYWHR or format == BoundingBoxFormat.XYXYXYXY
46-
)
43+
def is_rotated_bounding_format(format: Union[BoundingBoxFormat, str]) -> bool:
44+
if isinstance(format, BoundingBoxFormat):
45+
return (format == BoundingBoxFormat.XYWHR or format == BoundingBoxFormat.CXCYWHR or format == BoundingBoxFormat.XYXYXYXY)
46+
elif isinstance(format, str):
47+
return format in ("XYWHR", "CXCYWHR", "XYXYXYXY")
48+
else:
49+
raise ValueError(f"format should be str or BoundingBoxFormat, got {type(format)}")
4750

4851

4952
# TODOBB consider making this a Literal instead. Tried briefly and got
@@ -111,7 +114,7 @@ def __new__(
111114
) -> BoundingBoxes:
112115
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
113116
if not torch.is_floating_point(tensor) and is_rotated_bounding_format(format):
114-
raise ValueError("Rotated bounding boxes should be floating point tensors")
117+
raise ValueError(f"Rotated bounding boxes should be floating point tensors, got {tensor.dtype}.")
115118
return cls._wrap(tensor, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
116119

117120
@classmethod

0 commit comments

Comments
 (0)