Skip to content

Add box_area_center and box_iou_center functions for cxcywh format with tests #8992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 160 additions & 12 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _helper_boxes_shape(self, func):
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
func(a, boxes, output_size=(2, 2))

# test boxes as List[Tensor[N, 4]]
# test boxes as list[Tensor[N, 4]]
with pytest.raises(AssertionError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
Expand Down Expand Up @@ -1445,9 +1445,9 @@ def test_bbox_convert_jit(self):
torch.testing.assert_close(scripted_cxcywh, box_cxcywh)


class TestBoxArea:
class TestBoxAreaXYXY:
def area_check(self, box, expected, atol=1e-4):
out = ops.box_area(box)
out = ops.box_area(box, fmt="xyxy")
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)

@pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
Expand All @@ -1472,12 +1472,54 @@ def test_float16_box(self):

def test_box_area_jit(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
expected = ops.box_area(box_tensor)
expected = ops.box_area(box_tensor, fmt="xyxy")
scripted_fn = torch.jit.script(ops.box_area)
scripted_area = scripted_fn(box_tensor)
torch.testing.assert_close(scripted_area, expected)


class TestBoxAreaCXCYWH:
    """Tests for ``ops.box_area`` when boxes are given as (cx, cy, w, h)."""

    def area_check(self, box, expected, atol=1e-4):
        """Assert the ``cxcywh`` area of ``box`` matches ``expected`` within ``atol``."""
        out = ops.box_area(box, fmt="cxcywh")
        torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        # Build the cxcywh boxes directly in the integer dtype under test
        # (xyxy [0, 0, 100, 100] == cxcywh [50, 50, 100, 100]).
        # NOTE(review): converting integer xyxy boxes with ops.box_convert appears
        # to promote them to floating point (the centre is a sum divided by 2), in
        # which case the integer code path of box_area would never be exercised.
        box_tensor = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        expected = torch.tensor([10000, 0], dtype=torch.int32)
        self.area_check(box_tensor, expected)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt="cxcywh")
        expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
        self.area_check(box_tensor, expected)

    def test_float16_box(self):
        box_tensor = ops.box_convert(
            torch.tensor(
                [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]],
                dtype=torch.float16,
            ),
            in_fmt="xyxy",
            out_fmt="cxcywh",
        )

        expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        # Looser tolerance than the default: inputs and expected values are float16.
        self.area_check(box_tensor, expected, atol=0.01)

    def test_box_area_jit(self):
        # The scripted function must agree with eager mode when fmt is passed explicitly.
        box_tensor = ops.box_convert(
            torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt="cxcywh"
        )
        expected = ops.box_area(box_tensor, fmt="cxcywh")
        scripted_fn = torch.jit.script(ops.box_area)
        scripted_area = scripted_fn(box_tensor, fmt="cxcywh")
        torch.testing.assert_close(scripted_area, expected)


INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
FLOAT_BOXES = [
Expand All @@ -1486,29 +1528,37 @@ def test_box_area_jit(self):
[279.2440, 197.9812, 1189.4746, 849.2019],
]

# Fixtures in (cx, cy, w, h) layout for the cxcywh IoU tests.
# The first three integer boxes are the cxcywh form of the xyxy boxes
# [0, 0, 100, 100], [0, 0, 50, 50] and [200, 200, 300, 300]. The fourth,
# [10, 10, 20, 20], is xyxy [0, 0, 20, 20]; it is NOT the cxcywh form of
# INT_BOXES' last entry [0, 0, 25, 25].
INT_BOXES_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100], [10, 10, 20, 20]]
INT_BOXES2_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100]]
# Float boxes in cxcywh layout. The last row corresponds to the xyxy box
# [279.2440, 197.9812, 1189.4746, 849.2019]; presumably the other rows mirror
# the remaining FLOAT_BOXES entries in the same way — TODO confirm.
FLOAT_BOXES_CXCYWH = [
[739.4324, 518.5154, 908.1572, 665.8793],
[738.8228, 519.9021, 907.3512, 662.3295],
[734.3593, 523.5916, 910.2306, 651.2207],
]


def gen_box(size, dtype=torch.float) -> Tensor:
    """Return ``size`` random boxes as a ``(size, 4)`` tensor in (x1, y1, x2, y2) layout.

    The second corner is the first corner plus a non-negative random offset,
    so ``x2 >= x1`` and ``y2 >= y1`` for every box.
    """
    top_left = torch.rand((size, 2), dtype=dtype)
    # Drawn after ``top_left`` so the random stream is consumed in the same order.
    offset = torch.rand((size, 2), dtype=dtype)
    bottom_right = top_left + offset
    return torch.cat((top_left, bottom_right), dim=-1)


class TestIouBase:
class TestIouXYXYBase:
@staticmethod
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
for dtype in dtypes:
actual_box1 = torch.tensor(actual_box1, dtype=dtype)
actual_box2 = torch.tensor(actual_box2, dtype=dtype)
expected_box = torch.tensor(expected)
out = target_fn(actual_box1, actual_box2)
out = target_fn(actual_box1, actual_box2, fmt="xyxy")
torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

@staticmethod
def _run_jit_test(target_fn: Callable, actual_box: list):
box_tensor = torch.tensor(actual_box, dtype=torch.float)
expected = target_fn(box_tensor, box_tensor)
expected = target_fn(box_tensor, box_tensor, fmt="xyxy")
scripted_fn = torch.jit.script(target_fn)
scripted_out = scripted_fn(box_tensor, box_tensor)
scripted_out = scripted_fn(box_tensor, box_tensor, fmt="xyxy")
torch.testing.assert_close(scripted_out, expected)

@staticmethod
Expand All @@ -1518,15 +1568,15 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
result = torch.zeros((N, M))
for i in range(N):
for j in range(M):
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="xyxy")
return result

@staticmethod
def _run_cartesian_test(target_fn: Callable):
boxes1 = gen_box(5)
boxes2 = gen_box(7)
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
b = target_fn(boxes1, boxes2)
a = TestIouXYXYBase._cartesian_product(boxes1, boxes2, target_fn)
b = target_fn(boxes1, boxes2, fmt="xyxy")
torch.testing.assert_close(a, b)

@staticmethod
Expand All @@ -1538,7 +1588,7 @@ def _run_batch_test(target_fn: Callable):
torch.testing.assert_close(native, iterative)


class TestBoxIou(TestIouBase):
class TestBoxIouXYXY(TestIouXYXYBase):
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

Expand All @@ -1563,6 +1613,104 @@ def test_iou_batch(self):
self._run_batch_test(ops.box_iou)


class TestIouCXCYWHBase:
    """Shared helpers for IoU-style tests that pass boxes in ``cxcywh`` format."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn(box1, box2, fmt="cxcywh")`` against ``expected`` for each dtype."""
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Build fresh tensors from the raw inputs on every iteration. Rebinding
            # the arguments would feed the previous iteration's tensor back into
            # torch.tensor(), carrying over that dtype's rounding (e.g. float32
            # values re-cast to float64) and triggering a copy-construct warning.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2, fmt="cxcywh")
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: list):
        """Check that the scripted ``target_fn`` matches eager mode on ``actual_box``."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor, fmt="cxcywh")
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor, fmt="cxcywh")
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Build the N x M result by calling ``target_fn`` on one pair of boxes at a time."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="cxcywh")
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check the batched N x M result against the pair-by-pair result."""
        # gen_box yields xyxy boxes; convert them so the inputs really are cxcywh.
        boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh")
        boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh")
        a = TestIouCXCYWHBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2, fmt="cxcywh")
        torch.testing.assert_close(a, b)


class TestBoxIouCXCYWH(TestIouCXCYWHBase):
    """Runs the shared cxcywh IoU checks against ``ops.box_iou``."""

    # Expected pairwise IoU of INT_BOXES_CXCYWH (rows) against INT_BOXES2_CXCYWH (columns).
    int_expected = [
        [1.0, 0.25, 0.0],
        [0.25, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.04, 0.16, 0.0],
    ]
    # Expected pairwise IoU of FLOAT_BOXES_CXCYWH against itself.
    float_expected = [
        [1.0, 0.9933, 0.9673],
        [0.9933, 1.0, 0.9737],
        [0.9673, 0.9737, 1.0],
    ]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            (INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            (FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected),
            (FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.box_iou, INT_BOXES_CXCYWH)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.box_iou)


class TestIouBase:
    """Shared helpers for IoU-style tests that call ``target_fn`` without a format argument."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn(box1, box2)`` against ``expected`` for each dtype in ``dtypes``."""
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Build fresh tensors from the raw inputs on every iteration. Rebinding
            # the arguments would feed the previous iteration's tensor back into
            # torch.tensor(), carrying over that dtype's rounding (e.g. float32
            # values re-cast to float64) and triggering a copy-construct warning.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: list):
        """Check that the scripted ``target_fn`` matches eager mode on ``actual_box``."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Build the N x M result by calling ``target_fn`` on one pair of boxes at a time."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check the batched N x M result against the pair-by-pair result."""
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)


class TestGeneralizedBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
Expand Down
86 changes: 64 additions & 22 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,60 +270,102 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
return boxes


def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
    """
    Computes the area of a set of bounding boxes given in a chosen format.

    Args:
        boxes (Tensor[..., 4]): boxes for which the area will be computed. The
            last dimension holds the four box values, laid out according to ``fmt``.
        fmt (str): Format of the input boxes. Supported formats are "xyxy",
            "xywh", and "cxcywh". Default is "xyxy" to preserve backward compatibility.

    Returns:
        Tensor[...]: the area of each box; the last dimension of ``boxes`` is consumed.

    Raises:
        ValueError: if ``fmt`` is not one of the supported formats.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_area)
    allowed_fmts = ("xyxy", "xywh", "cxcywh")
    if fmt not in allowed_fmts:
        raise ValueError(f"Unsupported Bounding Box area for given format {fmt}")
    boxes = _upcast(boxes)
    if fmt == "xyxy":
        # Corner format: the extents are the differences between the two corners.
        width = boxes[..., 2] - boxes[..., 0]
        height = boxes[..., 3] - boxes[..., 1]
    else:
        # "xywh" and "cxcywh" both store width and height in the last two slots.
        width = boxes[..., 2]
        height = boxes[..., 3]
    return width * height


def _box_corners(boxes: Tensor, fmt: str) -> tuple[Tensor, Tensor]:
    """Return the (top-left, bottom-right) corners of ``boxes`` laid out as ``fmt``."""
    if fmt == "xyxy":
        # Already corner format: slice only, leaving the xyxy path untouched.
        return boxes[..., :2], boxes[..., 2:]
    # Upcast before any arithmetic, mirroring box_area, so that ``xy + wh`` and
    # ``wh / 2`` are not evaluated in a narrow input dtype.
    # NOTE(review): assumes _upcast widens low-precision dtypes (e.g. int8/int16,
    # float16) — confirm against its definition elsewhere in this file.
    boxes = _upcast(boxes)
    xy = boxes[..., :2]
    wh = boxes[..., 2:]
    if fmt == "xywh":
        return xy, xy + wh
    # fmt == "cxcywh": xy is the box centre.
    half_wh = wh / 2
    return xy - half_wh, xy + half_wh


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]:
    """
    Compute pairwise intersection and union areas between two sets of boxes.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes, laid out as ``fmt``.
        boxes2 (Tensor[..., M, 4]): second set of boxes, laid out as ``fmt``.
        fmt (str): one of "xyxy", "xywh" or "cxcywh". Default is "xyxy".

    Returns:
        tuple[Tensor, Tensor]: ``(inter, union)``, each of shape ``[..., N, M]``.

    Raises:
        ValueError: if ``fmt`` is not one of the supported formats.
    """
    # Validate first: box_area performs its own format check, so checking after
    # the area calls would leave this message unreachable.
    allowed_fmts = (
        "xyxy",
        "xywh",
        "cxcywh",
    )
    if fmt not in allowed_fmts:
        raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}.")

    area1 = box_area(boxes1, fmt=fmt)
    area2 = box_area(boxes2, fmt=fmt)

    lt1, rb1 = _box_corners(boxes1, fmt)
    lt2, rb2 = _box_corners(boxes2, fmt)

    # Broadcast every box of boxes1 against every box of boxes2.
    lt = torch.max(lt1[..., None, :], lt2[..., None, :, :])  # [...,N,M,2]
    rb = torch.min(rb1[..., None, :], rb2[..., None, :, :])  # [...,N,M,2]

    # Non-overlapping pairs yield negative extents; clamp them to zero area.
    wh = _upcast(rb - lt).clamp(min=0)  # [...,N,M,2]
    inter = wh[..., 0] * wh[..., 1]  # [...,N,M]

    union = area1[..., None] + area2[..., None, :] - inter

    return inter, union


def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
    """
    Return the pairwise intersection-over-union (Jaccard index) of two sets of boxes.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes
        boxes2 (Tensor[..., M, 4]): second set of boxes
        fmt (str): Format shared by both sets of boxes. Supported formats are
            "xyxy", "xywh", and "cxcywh". Default is "xyxy" to preserve
            backward compatibility.

    Returns:
        Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
        in boxes1 and boxes2

    Raises:
        ValueError: if ``fmt`` is not one of the supported formats.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_iou)
    supported_fmts = ("xyxy", "xywh", "cxcywh")
    if fmt not in supported_fmts:
        raise ValueError(f"Unsupported Box IoU Calculation for given format {fmt}.")
    intersection, union_area = _box_inter_union(boxes1, boxes2, fmt=fmt)
    return intersection / union_area

Expand Down
Loading