Skip to content

Commit 64666c7

Browse files
Add box_area_center and box_iou_center functions for cxcywh format with tests (#8992)
Co-authored-by: Antoine Simoulin <[email protected]>
1 parent 97920a5 commit 64666c7

File tree

2 files changed

+150
-63
lines changed

2 files changed

+150
-63
lines changed

test/test_ops.py

Lines changed: 89 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import math
22
import os
33
from abc import ABC, abstractmethod
4-
from functools import lru_cache
4+
from functools import lru_cache, partial
55
from itertools import product
66
from typing import Callable
77

@@ -242,7 +242,7 @@ def _helper_boxes_shape(self, func):
242242
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
243243
func(a, boxes, output_size=(2, 2))
244244

245-
# test boxes as List[Tensor[N, 4]]
245+
# test boxes as list[Tensor[N, 4]]
246246
with pytest.raises(AssertionError):
247247
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
248248
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
@@ -1446,34 +1446,60 @@ def test_bbox_convert_jit(self):
14461446

14471447

14481448
class TestBoxArea:
1449-
def area_check(self, box, expected, atol=1e-4):
1450-
out = ops.box_area(box)
1449+
def area_check(self, box, expected, fmt="xyxy", atol=1e-4):
1450+
out = ops.box_area(box, fmt=fmt)
14511451
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)
14521452

14531453
@pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
1454-
def test_int_boxes(self, dtype):
1455-
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
1454+
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
1455+
def test_int_boxes(self, dtype, fmt):
1456+
box_tensor = ops.box_convert(
1457+
torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype), in_fmt="xyxy", out_fmt=fmt
1458+
)
14561459
expected = torch.tensor([10000, 0], dtype=torch.int32)
1457-
self.area_check(box_tensor, expected)
1460+
self.area_check(box_tensor, expected, fmt)
14581461

14591462
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
1460-
def test_float_boxes(self, dtype):
1461-
box_tensor = torch.tensor(FLOAT_BOXES, dtype=dtype)
1463+
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
1464+
def test_float_boxes(self, dtype, fmt):
1465+
box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
14621466
expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
1463-
self.area_check(box_tensor, expected)
1464-
1465-
def test_float16_box(self):
1466-
box_tensor = torch.tensor(
1467-
[[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
1467+
self.area_check(box_tensor, expected, fmt)
1468+
1469+
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
1470+
def test_float16_box(self, fmt):
1471+
box_tensor = ops.box_convert(
1472+
torch.tensor(
1473+
[[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]],
1474+
dtype=torch.float16,
1475+
),
1476+
in_fmt="xyxy",
1477+
out_fmt=fmt,
14681478
)
14691479

14701480
expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
1471-
self.area_check(box_tensor, expected, atol=0.01)
1481+
self.area_check(box_tensor, expected, fmt, atol=0.01)
1482+
1483+
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
1484+
def test_box_area_jit(self, fmt):
1485+
box_tensor = ops.box_convert(
1486+
torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt=fmt
1487+
)
1488+
expected = ops.box_area(box_tensor, fmt)
1489+
1490+
class BoxArea(torch.nn.Module):
1491+
# We are using this intermediate class
1492+
# since torchscript does not support
1493+
# either partial or lambda functions for this test.
1494+
def __init__(self, fmt):
1495+
super().__init__()
1496+
self.area = ops.box_area
1497+
self.fmt = fmt
14721498

1473-
def test_box_area_jit(self):
1474-
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
1475-
expected = ops.box_area(box_tensor)
1476-
scripted_fn = torch.jit.script(ops.box_area)
1499+
def forward(self, boxes):
1500+
return self.area(boxes, self.fmt)
1501+
1502+
scripted_fn = torch.jit.script(BoxArea(fmt))
14771503
scripted_area = scripted_fn(box_tensor)
14781504
torch.testing.assert_close(scripted_area, expected)
14791505

@@ -1487,25 +1513,28 @@ def test_box_area_jit(self):
14871513
]
14881514

14891515

1490-
def gen_box(size, dtype=torch.float) -> Tensor:
1516+
def gen_box(size, dtype=torch.float, fmt="xyxy") -> Tensor:
14911517
xy1 = torch.rand((size, 2), dtype=dtype)
14921518
xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
1493-
return torch.cat([xy1, xy2], axis=-1)
1519+
return ops.box_convert(torch.cat([xy1, xy2], axis=-1), in_fmt="xyxy", out_fmt=fmt)
14941520

14951521

14961522
class TestIouBase:
14971523
@staticmethod
1498-
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
1524+
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected, fmt="xyxy"):
14991525
for dtype in dtypes:
1500-
actual_box1 = torch.tensor(actual_box1, dtype=dtype)
1501-
actual_box2 = torch.tensor(actual_box2, dtype=dtype)
1526+
_actual_box1 = ops.box_convert(torch.tensor(actual_box1, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
1527+
_actual_box2 = ops.box_convert(torch.tensor(actual_box2, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
15021528
expected_box = torch.tensor(expected)
1503-
out = target_fn(actual_box1, actual_box2)
1529+
out = target_fn(
1530+
_actual_box1,
1531+
_actual_box2,
1532+
)
15041533
torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)
15051534

15061535
@staticmethod
1507-
def _run_jit_test(target_fn: Callable, actual_box: list):
1508-
box_tensor = torch.tensor(actual_box, dtype=torch.float)
1536+
def _run_jit_test(target_fn: Callable, actual_box: list, fmt="xyxy"):
1537+
box_tensor = ops.box_convert(torch.tensor(actual_box, dtype=torch.float), in_fmt="xyxy", out_fmt=fmt)
15091538
expected = target_fn(box_tensor, box_tensor)
15101539
scripted_fn = torch.jit.script(target_fn)
15111540
scripted_out = scripted_fn(box_tensor, box_tensor)
@@ -1522,17 +1551,17 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
15221551
return result
15231552

15241553
@staticmethod
1525-
def _run_cartesian_test(target_fn: Callable):
1526-
boxes1 = gen_box(5)
1527-
boxes2 = gen_box(7)
1554+
def _run_cartesian_test(target_fn: Callable, fmt: str = "xyxy"):
1555+
boxes1 = gen_box(5, fmt=fmt)
1556+
boxes2 = gen_box(7, fmt=fmt)
15281557
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
15291558
b = target_fn(boxes1, boxes2)
15301559
torch.testing.assert_close(a, b)
15311560

15321561
@staticmethod
1533-
def _run_batch_test(target_fn: Callable):
1534-
boxes1 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
1535-
boxes2 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
1562+
def _run_batch_test(target_fn: Callable, fmt: str = "xyxy"):
1563+
boxes1 = torch.stack([gen_box(5, fmt=fmt) for _ in range(3)], dim=0)
1564+
boxes2 = torch.stack([gen_box(5, fmt=fmt) for _ in range(3)], dim=0)
15361565
native: Tensor = target_fn(boxes1, boxes2)
15371566
iterative: Tensor = torch.stack([target_fn(*pairs) for pairs in zip(boxes1, boxes2)], dim=0)
15381567
torch.testing.assert_close(native, iterative)
@@ -1550,17 +1579,33 @@ class TestBoxIou(TestIouBase):
15501579
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
15511580
],
15521581
)
1553-
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
1554-
self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)
1555-
1556-
def test_iou_jit(self):
1557-
self._run_jit_test(ops.box_iou, INT_BOXES)
1558-
1559-
def test_iou_cartesian(self):
1560-
self._run_cartesian_test(ops.box_iou)
1561-
1562-
def test_iou_batch(self):
1563-
self._run_batch_test(ops.box_iou)
1582+
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
1583+
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected, fmt):
1584+
self._run_test(partial(ops.box_iou, fmt=fmt), actual_box1, actual_box2, dtypes, atol, expected, fmt)
1585+
1586+
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
1587+
def test_iou_jit(self, fmt):
1588+
class IoUJit(torch.nn.Module):
1589+
# We are using this intermediate class
1590+
# since torchscript does not support
1591+
# either partial or lambda functions for this test.
1592+
def __init__(self, fmt):
1593+
super().__init__()
1594+
self.iou = ops.box_iou
1595+
self.fmt = fmt
1596+
1597+
def forward(self, boxes1, boxes2):
1598+
return self.iou(boxes1, boxes2, fmt=self.fmt)
1599+
1600+
self._run_jit_test(IoUJit(fmt=fmt), INT_BOXES, fmt)
1601+
1602+
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
1603+
def test_iou_cartesian(self, fmt):
1604+
self._run_cartesian_test(partial(ops.box_iou, fmt=fmt))
1605+
1606+
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
1607+
def test_iou_batch(self, fmt):
1608+
self._run_batch_test(partial(ops.box_iou, fmt=fmt))
15641609

15651610

15661611
class TestGeneralizedBoxIou(TestIouBase):

torchvision/ops/boxes.py

Lines changed: 61 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -270,33 +270,68 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
270270
return boxes
271271

272272

273-
def box_area(boxes: Tensor) -> Tensor:
273+
def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
274274
"""
275-
Computes the area of a set of bounding boxes, which are specified by their
276-
(x1, y1, x2, y2) coordinates.
275+
Computes the area of a set of bounding boxes from a given format.
277276
278277
Args:
279-
boxes (Tensor[..., 4]): boxes for which the area will be computed. They
280-
are expected to be in (x1, y1, x2, y2) format with
281-
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
278+
boxes (Tensor[..., 4]): boxes for which the area will be computed.
279+
fmt (str): Format of the input boxes.
280+
Default is "xyxy" to preserve backward compatibility.
281+
Supported formats are "xyxy", "xywh", and "cxcywh".
282282
283283
Returns:
284-
Tensor[N]: the area for each box
284+
Tensor[N]: Tensor containing the area for each box.
285285
"""
286286
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
287287
_log_api_usage_once(box_area)
288+
allowed_fmts = (
289+
"xyxy",
290+
"xywh",
291+
"cxcywh",
292+
)
293+
if fmt not in allowed_fmts:
294+
raise ValueError(f"Unsupported Bounding Box area for given format {fmt}")
288295
boxes = _upcast(boxes)
289-
return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
296+
if fmt == "xyxy":
297+
area = (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
298+
else:
299+
# For formats with width and height, area = width * height
300+
# Supported: cxcywh, xywh
301+
area = boxes[..., 2] * boxes[..., 3]
302+
303+
return area
290304

291305

292306
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
293307
# with slight modifications
294-
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
295-
area1 = box_area(boxes1)
296-
area2 = box_area(boxes2)
308+
def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]:
309+
area1 = box_area(boxes1, fmt=fmt)
310+
area2 = box_area(boxes2, fmt=fmt)
297311

298-
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
299-
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
312+
allowed_fmts = (
313+
"xyxy",
314+
"xywh",
315+
"cxcywh",
316+
)
317+
if fmt not in allowed_fmts:
318+
raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}.")
319+
320+
if fmt == "xyxy":
321+
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
322+
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
323+
elif fmt == "xywh":
324+
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
325+
rb = torch.min(
326+
boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:]
327+
) # [...,N,M,2]
328+
else: # fmt == "cxcywh":
329+
lt = torch.max(
330+
boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2
331+
) # [N,M,2]
332+
rb = torch.min(
333+
boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2
334+
) # [N,M,2]
300335

301336
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
302337
inter = wh[..., 0] * wh[..., 1] # [N,M]
@@ -306,24 +341,31 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
306341
return inter, union
307342

308343

309-
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
344+
def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
310345
"""
311-
Return intersection-over-union (Jaccard index) between two sets of boxes.
312-
313-
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
314-
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
346+
Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.
315347
316348
Args:
317349
boxes1 (Tensor[..., N, 4]): first set of boxes
318350
boxes2 (Tensor[..., M, 4]): second set of boxes
351+
fmt (str): Format of the input boxes.
352+
Default is "xyxy" to preserve backward compatibility.
353+
Supported formats are "xyxy", "xywh", and "cxcywh".
319354
320355
Returns:
321356
Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
322357
in boxes1 and boxes2
323358
"""
324359
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
325360
_log_api_usage_once(box_iou)
326-
inter, union = _box_inter_union(boxes1, boxes2)
361+
allowed_fmts = (
362+
"xyxy",
363+
"xywh",
364+
"cxcywh",
365+
)
366+
if fmt not in allowed_fmts:
367+
raise ValueError(f"Unsupported Box IoU Calculation for given format {fmt}.")
368+
inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt)
327369
iou = inter / union
328370
return iou
329371

0 commit comments

Comments
 (0)