Skip to content

Commit 0eba325

Browse files
committed
Make many box ops batch-dim compatible. Add test for batched calculations.
Signed-off-by: Bryce Ferenczi <[email protected]>
1 parent 5f03dc5 commit 0eba325

File tree

2 files changed

+63
-40
lines changed

2 files changed

+63
-40
lines changed

test/test_ops.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,15 +1073,15 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None):
10731073
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
10741074

10751075
torch.testing.assert_close(
1076-
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
1076+
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres: \n{res}\nexpected: \n{expected}"
10771077
)
10781078

10791079
# no modulation test
10801080
res = layer(x, offset)
10811081
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
10821082

10831083
torch.testing.assert_close(
1084-
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
1084+
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres: \n{res}\nexpected: \n{expected}"
10851085
)
10861086

10871087
def test_wrong_sizes(self):
@@ -1468,7 +1468,7 @@ def test_box_area_jit(self):
14681468
]
14691469

14701470

1471-
def gen_box(size, dtype=torch.float):
1471+
def gen_box(size, dtype=torch.float) -> Tensor:
14721472
xy1 = torch.rand((size, 2), dtype=dtype)
14731473
xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
14741474
return torch.cat([xy1, xy2], axis=-1)
@@ -1510,6 +1510,14 @@ def _run_cartesian_test(target_fn: Callable):
15101510
b = target_fn(boxes1, boxes2)
15111511
torch.testing.assert_close(a, b)
15121512

1513+
@staticmethod
1514+
def _run_batch_test(target_fn: Callable):
1515+
boxes1 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
1516+
boxes2 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
1517+
native: Tensor = target_fn(boxes1, boxes2)
1518+
iterative: Tensor = torch.stack([target_fn(*pairs) for pairs in zip(boxes1, boxes2)], dim=0)
1519+
torch.testing.assert_close(native, iterative)
1520+
15131521

15141522
class TestBoxIou(TestIouBase):
15151523
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
@@ -1532,6 +1540,9 @@ def test_iou_jit(self):
15321540
def test_iou_cartesian(self):
15331541
self._run_cartesian_test(ops.box_iou)
15341542

1543+
def test_iou_batch(self):
1544+
self._run_batch_test(ops.box_iou)
1545+
15351546

15361547
class TestGeneralizedBoxIou(TestIouBase):
15371548
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
@@ -1554,6 +1565,9 @@ def test_iou_jit(self):
15541565
def test_iou_cartesian(self):
15551566
self._run_cartesian_test(ops.generalized_box_iou)
15561567

1568+
def test_iou_batch(self):
1569+
self._run_batch_test(ops.generalized_box_iou)
1570+
15571571

15581572
class TestDistanceBoxIoU(TestIouBase):
15591573
int_expected = [
@@ -1581,6 +1595,9 @@ def test_iou_jit(self):
15811595
def test_iou_cartesian(self):
15821596
self._run_cartesian_test(ops.distance_box_iou)
15831597

1598+
def test_iou_batch(self):
1599+
self._run_batch_test(ops.distance_box_iou)
1600+
15841601

15851602
class TestCompleteBoxIou(TestIouBase):
15861603
int_expected = [
@@ -1608,6 +1625,9 @@ def test_iou_jit(self):
16081625
def test_iou_cartesian(self):
16091626
self._run_cartesian_test(ops.complete_box_iou)
16101627

1628+
def test_iou_batch(self):
1629+
self._run_batch_test(ops.complete_box_iou)
1630+
16111631

16121632
def get_boxes(dtype, device):
16131633
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)

torchvision/ops/boxes.py

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
130130
the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.
131131
132132
Args:
133-
boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
133+
boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
134134
with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
135135
min_size (float): minimum size
136136
@@ -140,7 +140,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
140140
"""
141141
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
142142
_log_api_usage_once(remove_small_boxes)
143-
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
143+
ws, hs = boxes[..., 2] - boxes[..., 0], boxes[..., 3] - boxes[..., 1]
144144
keep = (ws >= min_size) & (hs >= min_size)
145145
keep = torch.where(keep)[0]
146146
return keep
@@ -155,12 +155,12 @@ def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor:
155155
the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.
156156
157157
Args:
158-
boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
158+
boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
159159
with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
160160
size (Tuple[height, width]): size of the image
161161
162162
Returns:
163-
Tensor[N, 4]: clipped boxes
163+
Tensor[..., 4]: clipped boxes
164164
"""
165165
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
166166
_log_api_usage_once(clip_boxes_to_image)
@@ -276,7 +276,7 @@ def box_area(boxes: Tensor) -> Tensor:
276276
(x1, y1, x2, y2) coordinates.
277277
278278
Args:
279-
boxes (Tensor[N, 4]): boxes for which the area will be computed. They
279+
boxes (Tensor[..., 4]): boxes for which the area will be computed. They
280280
are expected to be in (x1, y1, x2, y2) format with
281281
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
282282
@@ -286,7 +286,7 @@ def box_area(boxes: Tensor) -> Tensor:
286286
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
287287
_log_api_usage_once(box_area)
288288
boxes = _upcast(boxes)
289-
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
289+
return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
290290

291291

292292
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
@@ -295,13 +295,13 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
295295
area1 = box_area(boxes1)
296296
area2 = box_area(boxes2)
297297

298-
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
299-
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
298+
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
299+
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
300300

301301
wh = _upcast(rb - lt).clamp(min=0) # [...,N,M,2]
302-
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
302+
inter = wh[..., 0] * wh[..., 1] # [...,N,M]
303303

304-
union = area1[:, None] + area2 - inter
304+
union = area1[..., None] + area2[..., None, :] - inter
305305

306306
return inter, union
307307

@@ -314,11 +314,12 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
314314
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
315315
316316
Args:
317-
boxes1 (Tensor[N, 4]): first set of boxes
318-
boxes2 (Tensor[M, 4]): second set of boxes
317+
boxes1 (Tensor[..., N, 4]): first set of boxes
318+
boxes2 (Tensor[..., M, 4]): second set of boxes
319319
320320
Returns:
321-
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
321+
Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
322+
in boxes1 and boxes2
322323
"""
323324
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
324325
_log_api_usage_once(box_iou)
@@ -336,11 +337,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
336337
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
337338
338339
Args:
339-
boxes1 (Tensor[N, 4]): first set of boxes
340-
boxes2 (Tensor[M, 4]): second set of boxes
340+
boxes1 (Tensor[..., N, 4]): first set of boxes
341+
boxes2 (Tensor[..., M, 4]): second set of boxes
341342
342343
Returns:
343-
Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
344+
Tensor[..., N, M]: the NxM matrix containing the pairwise generalized IoU values
344345
for every element in boxes1 and boxes2
345346
"""
346347
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -349,11 +350,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
349350
inter, union = _box_inter_union(boxes1, boxes2)
350351
iou = inter / union
351352

352-
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
353-
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
353+
lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
354+
rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
354355

355356
whi = _upcast(rbi - lti).clamp(min=0) # [...,N,M,2]
356-
areai = whi[:, :, 0] * whi[:, :, 1]
357+
areai = whi[..., 0] * whi[..., 1]
357358

358359
return iou - (areai - union) / areai
359360

@@ -364,11 +365,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
364365
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
365366
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
366367
Args:
367-
boxes1 (Tensor[N, 4]): first set of boxes
368-
boxes2 (Tensor[M, 4]): second set of boxes
368+
boxes1 (Tensor[..., N, 4]): first set of boxes
369+
boxes2 (Tensor[..., M, 4]): second set of boxes
369370
eps (float, optional): small number to prevent division by zero. Default: 1e-7
370371
Returns:
371-
Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
372+
Tensor[..., N, M]: the NxM matrix containing the pairwise complete IoU values
372373
for every element in boxes1 and boxes2
373374
"""
374375
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -379,11 +380,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
379380

380381
diou, iou = _box_diou_iou(boxes1, boxes2, eps)
381382

382-
w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
383-
h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]
383+
w_pred = boxes1[..., None, 2] - boxes1[..., None, 0]
384+
h_pred = boxes1[..., None, 3] - boxes1[..., None, 1]
384385

385-
w_gt = boxes2[:, 2] - boxes2[:, 0]
386-
h_gt = boxes2[:, 3] - boxes2[:, 1]
386+
w_gt = boxes2[..., None, :, 2] - boxes2[..., None, :, 0]
387+
h_gt = boxes2[..., None, :, 3] - boxes2[..., None, :, 1]
387388

388389
v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
389390
with torch.no_grad():
@@ -399,12 +400,12 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
399400
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
400401
401402
Args:
402-
boxes1 (Tensor[N, 4]): first set of boxes
403-
boxes2 (Tensor[M, 4]): second set of boxes
403+
boxes1 (Tensor[..., N, 4]): first set of boxes
404+
boxes2 (Tensor[..., M, 4]): second set of boxes
404405
eps (float, optional): small number to prevent division by zero. Default: 1e-7
405406
406407
Returns:
407-
Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
408+
Tensor[..., N, M]: the NxM matrix containing the pairwise distance IoU values
408409
for every element in boxes1 and boxes2
409410
"""
410411
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -419,17 +420,19 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
419420
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> tuple[Tensor, Tensor]:
420421

421422
iou = box_iou(boxes1, boxes2)
422-
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
423-
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
423+
lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
424+
rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
424425
whi = _upcast(rbi - lti).clamp(min=0) # [...,N,M,2]
425-
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
426+
diagonal_distance_squared = (whi[..., 0] ** 2) + (whi[..., 1] ** 2) + eps
426427
# centers of boxes
427-
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
428-
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
429-
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
430-
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
428+
x_p = (boxes1[..., 0] + boxes1[..., 2]) / 2
429+
y_p = (boxes1[..., 1] + boxes1[..., 3]) / 2
430+
x_g = (boxes2[..., 0] + boxes2[..., 2]) / 2
431+
y_g = (boxes2[..., 1] + boxes2[..., 3]) / 2
431432
# The distance between boxes' centers squared.
432-
centers_distance_squared = (_upcast(x_p[:, None] - x_g[None, :]) ** 2) + (_upcast(y_p[:, None] - y_g[None, :]) ** 2)
433+
centers_distance_squared = (_upcast(x_p[..., None] - x_g[..., None, :]) ** 2) + (
434+
_upcast(y_p[..., None] - y_g[..., None, :]) ** 2
435+
)
433436
# The distance IoU is the IoU penalized by a normalized
434437
# distance between boxes' centers squared.
435438
return iou - (centers_distance_squared / diagonal_distance_squared), iou

0 commit comments

Comments
 (0)