Skip to content

Commit f515c45

Browse files
5had3zscotts
andauthored
Batched Box Ops (#9058)
Signed-off-by: Bryce Ferenczi <[email protected]> Co-authored-by: Scott Schneider <[email protected]>
1 parent 4a6ae15 commit f515c45

File tree

2 files changed

+63
-40
lines changed

2 files changed

+63
-40
lines changed

test/test_ops.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,15 +1073,15 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None):
10731073
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
10741074

10751075
torch.testing.assert_close(
1076-
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
1076+
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres: \n{res}\nexpected: \n{expected}"
10771077
)
10781078

10791079
# no modulation test
10801080
res = layer(x, offset)
10811081
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
10821082

10831083
torch.testing.assert_close(
1084-
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
1084+
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres: \n{res}\nexpected: \n{expected}"
10851085
)
10861086

10871087
def test_wrong_sizes(self):
@@ -1487,7 +1487,7 @@ def test_box_area_jit(self):
14871487
]
14881488

14891489

1490-
def gen_box(size, dtype=torch.float):
1490+
def gen_box(size, dtype=torch.float) -> Tensor:
14911491
xy1 = torch.rand((size, 2), dtype=dtype)
14921492
xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
14931493
return torch.cat([xy1, xy2], axis=-1)
@@ -1529,6 +1529,14 @@ def _run_cartesian_test(target_fn: Callable):
15291529
b = target_fn(boxes1, boxes2)
15301530
torch.testing.assert_close(a, b)
15311531

1532+
@staticmethod
1533+
def _run_batch_test(target_fn: Callable):
1534+
boxes1 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
1535+
boxes2 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
1536+
native: Tensor = target_fn(boxes1, boxes2)
1537+
iterative: Tensor = torch.stack([target_fn(*pairs) for pairs in zip(boxes1, boxes2)], dim=0)
1538+
torch.testing.assert_close(native, iterative)
1539+
15321540

15331541
class TestBoxIou(TestIouBase):
15341542
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
@@ -1551,6 +1559,9 @@ def test_iou_jit(self):
15511559
def test_iou_cartesian(self):
15521560
self._run_cartesian_test(ops.box_iou)
15531561

1562+
def test_iou_batch(self):
1563+
self._run_batch_test(ops.box_iou)
1564+
15541565

15551566
class TestGeneralizedBoxIou(TestIouBase):
15561567
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
@@ -1573,6 +1584,9 @@ def test_iou_jit(self):
15731584
def test_iou_cartesian(self):
15741585
self._run_cartesian_test(ops.generalized_box_iou)
15751586

1587+
def test_iou_batch(self):
1588+
self._run_batch_test(ops.generalized_box_iou)
1589+
15761590

15771591
class TestDistanceBoxIoU(TestIouBase):
15781592
int_expected = [
@@ -1600,6 +1614,9 @@ def test_iou_jit(self):
16001614
def test_iou_cartesian(self):
16011615
self._run_cartesian_test(ops.distance_box_iou)
16021616

1617+
def test_iou_batch(self):
1618+
self._run_batch_test(ops.distance_box_iou)
1619+
16031620

16041621
class TestCompleteBoxIou(TestIouBase):
16051622
int_expected = [
@@ -1627,6 +1644,9 @@ def test_iou_jit(self):
16271644
def test_iou_cartesian(self):
16281645
self._run_cartesian_test(ops.complete_box_iou)
16291646

1647+
def test_iou_batch(self):
1648+
self._run_batch_test(ops.complete_box_iou)
1649+
16301650

16311651
def get_boxes(dtype, device):
16321652
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)

torchvision/ops/boxes.py

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
130130
the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.
131131
132132
Args:
133-
boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
133+
boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
134134
with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
135135
min_size (float): minimum size
136136
@@ -140,7 +140,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
140140
"""
141141
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
142142
_log_api_usage_once(remove_small_boxes)
143-
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
143+
ws, hs = boxes[..., 2] - boxes[..., 0], boxes[..., 3] - boxes[..., 1]
144144
keep = (ws >= min_size) & (hs >= min_size)
145145
keep = torch.where(keep)[0]
146146
return keep
@@ -155,12 +155,12 @@ def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor:
155155
the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.
156156
157157
Args:
158-
boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
158+
boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
159159
with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
160160
size (Tuple[height, width]): size of the image
161161
162162
Returns:
163-
Tensor[N, 4]: clipped boxes
163+
Tensor[..., 4]: clipped boxes
164164
"""
165165
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
166166
_log_api_usage_once(clip_boxes_to_image)
@@ -276,7 +276,7 @@ def box_area(boxes: Tensor) -> Tensor:
276276
(x1, y1, x2, y2) coordinates.
277277
278278
Args:
279-
boxes (Tensor[N, 4]): boxes for which the area will be computed. They
279+
boxes (Tensor[..., 4]): boxes for which the area will be computed. They
280280
are expected to be in (x1, y1, x2, y2) format with
281281
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
282282
@@ -286,7 +286,7 @@ def box_area(boxes: Tensor) -> Tensor:
286286
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
287287
_log_api_usage_once(box_area)
288288
boxes = _upcast(boxes)
289-
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
289+
return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
290290

291291

292292
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
@@ -295,13 +295,13 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
295295
area1 = box_area(boxes1)
296296
area2 = box_area(boxes2)
297297

298-
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
299-
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
298+
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
299+
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
300300

301301
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
302-
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
302+
inter = wh[..., 0] * wh[..., 1] # [N,M]
303303

304-
union = area1[:, None] + area2 - inter
304+
union = area1[..., None] + area2[..., None, :] - inter
305305

306306
return inter, union
307307

@@ -314,11 +314,12 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
314314
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
315315
316316
Args:
317-
boxes1 (Tensor[N, 4]): first set of boxes
318-
boxes2 (Tensor[M, 4]): second set of boxes
317+
boxes1 (Tensor[..., N, 4]): first set of boxes
318+
boxes2 (Tensor[..., M, 4]): second set of boxes
319319
320320
Returns:
321-
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
321+
Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
322+
in boxes1 and boxes2
322323
"""
323324
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
324325
_log_api_usage_once(box_iou)
@@ -336,11 +337,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
336337
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
337338
338339
Args:
339-
boxes1 (Tensor[N, 4]): first set of boxes
340-
boxes2 (Tensor[M, 4]): second set of boxes
340+
boxes1 (Tensor[..., N, 4]): first set of boxes
341+
boxes2 (Tensor[..., M, 4]): second set of boxes
341342
342343
Returns:
343-
Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
344+
Tensor[..., N, M]: the NxM matrix containing the pairwise generalized IoU values
344345
for every element in boxes1 and boxes2
345346
"""
346347
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -349,11 +350,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
349350
inter, union = _box_inter_union(boxes1, boxes2)
350351
iou = inter / union
351352

352-
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
353-
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
353+
lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
354+
rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
354355

355356
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
356-
areai = whi[:, :, 0] * whi[:, :, 1]
357+
areai = whi[..., 0] * whi[..., 1]
357358

358359
return iou - (areai - union) / areai
359360

@@ -364,11 +365,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
364365
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
365366
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
366367
Args:
367-
boxes1 (Tensor[N, 4]): first set of boxes
368-
boxes2 (Tensor[M, 4]): second set of boxes
368+
boxes1 (Tensor[..., N, 4]): first set of boxes
369+
boxes2 (Tensor[..., M, 4]): second set of boxes
369370
eps (float, optional): small number to prevent division by zero. Default: 1e-7
370371
Returns:
371-
Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
372+
Tensor[..., N, M]: the NxM matrix containing the pairwise complete IoU values
372373
for every element in boxes1 and boxes2
373374
"""
374375
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -379,11 +380,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
379380

380381
diou, iou = _box_diou_iou(boxes1, boxes2, eps)
381382

382-
w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
383-
h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]
383+
w_pred = boxes1[..., None, 2] - boxes1[..., None, 0]
384+
h_pred = boxes1[..., None, 3] - boxes1[..., None, 1]
384385

385-
w_gt = boxes2[:, 2] - boxes2[:, 0]
386-
h_gt = boxes2[:, 3] - boxes2[:, 1]
386+
w_gt = boxes2[..., None, :, 2] - boxes2[..., None, :, 0]
387+
h_gt = boxes2[..., None, :, 3] - boxes2[..., None, :, 1]
387388

388389
v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
389390
with torch.no_grad():
@@ -399,12 +400,12 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
399400
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
400401
401402
Args:
402-
boxes1 (Tensor[N, 4]): first set of boxes
403-
boxes2 (Tensor[M, 4]): second set of boxes
403+
boxes1 (Tensor[..., N, 4]): first set of boxes
404+
boxes2 (Tensor[..., M, 4]): second set of boxes
404405
eps (float, optional): small number to prevent division by zero. Default: 1e-7
405406
406407
Returns:
407-
Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
408+
Tensor[..., N, M]: the NxM matrix containing the pairwise distance IoU values
408409
for every element in boxes1 and boxes2
409410
"""
410411
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -419,17 +420,19 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
419420
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> tuple[Tensor, Tensor]:
420421

421422
iou = box_iou(boxes1, boxes2)
422-
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
423-
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
423+
lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
424+
rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
424425
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
425-
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
426+
diagonal_distance_squared = (whi[..., 0] ** 2) + (whi[..., 1] ** 2) + eps
426427
# centers of boxes
427-
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
428-
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
429-
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
430-
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
428+
x_p = (boxes1[..., 0] + boxes1[..., 2]) / 2
429+
y_p = (boxes1[..., 1] + boxes1[..., 3]) / 2
430+
x_g = (boxes2[..., 0] + boxes2[..., 2]) / 2
431+
y_g = (boxes2[..., 1] + boxes2[..., 3]) / 2
431432
# The distance between boxes' centers squared.
432-
centers_distance_squared = (_upcast(x_p[:, None] - x_g[None, :]) ** 2) + (_upcast(y_p[:, None] - y_g[None, :]) ** 2)
433+
centers_distance_squared = (_upcast(x_p[..., None] - x_g[..., None, :]) ** 2) + (
434+
_upcast(y_p[..., None] - y_g[..., None, :]) ** 2
435+
)
433436
# The distance IoU is the IoU penalized by a normalized
434437
# distance between boxes' centers squared.
435438
return iou - (centers_distance_squared / diagonal_distance_squared), iou

0 commit comments

Comments
 (0)