Skip to content

Commit 405f31e

Browse files
fix tests
1 parent d746794 commit 405f31e

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

test/test_ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,10 +1510,13 @@ class TestIouBase:
15101510
@staticmethod
15111511
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected, fmt="xyxy"):
15121512
for dtype in dtypes:
1513-
actual_box1 = ops.box_convert(torch.tensor(actual_box1, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
1514-
actual_box2 = ops.box_convert(torch.tensor(actual_box2, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
1513+
_actual_box1 = ops.box_convert(torch.tensor(actual_box1, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
1514+
_actual_box2 = ops.box_convert(torch.tensor(actual_box2, dtype=dtype), in_fmt="xyxy", out_fmt=fmt)
15151515
expected_box = torch.tensor(expected)
1516-
out = target_fn(actual_box1, actual_box2)
1516+
out = target_fn(
1517+
_actual_box1,
1518+
_actual_box2,
1519+
)
15171520
torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)
15181521

15191522
@staticmethod
@@ -1569,7 +1572,16 @@ def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected, fmt):
15691572

15701573
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
15711574
def test_iou_jit(self, fmt):
1572-
self._run_jit_test(partial(ops.box_iou, fmt=fmt), INT_BOXES, fmt)
1575+
class IoUJit(torch.nn.Module):
1576+
def __init__(self, fmt):
1577+
super().__init__()
1578+
self.iou = ops.box_iou
1579+
self.fmt = fmt
1580+
1581+
def forward(self, boxes1, boxes2):
1582+
return self.iou(boxes1, boxes2)
1583+
1584+
self._run_jit_test(IoUJit(fmt=fmt), INT_BOXES, fmt)
15731585

15741586
@pytest.mark.parametrize("fmt", ["xyxy", "xywh", "cxcywh"])
15751587
def test_iou_cartesian(self, fmt):

torchvision/ops/boxes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple
324324
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
325325
rb = torch.min(
326326
boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:]
327-
) # [N,M,2]
327+
) # [...,N,M,2]
328328
else: # fmt == "cxcywh":
329329
lt = torch.max(
330330
boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2
@@ -333,8 +333,8 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple
333333
boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2
334334
) # [N,M,2]
335335

336-
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
337-
inter = wh[..., 0] * wh[..., 1] # [N,M] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
336+
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
337+
inter = wh[..., 0] * wh[..., 1] # [N,M]
338338

339339
union = area1[..., None] + area2[..., None, :] - inter
340340

0 commit comments

Comments
 (0)