Skip to content

Commit 8bc72b6

Browse files
Keep batch dimension
1 parent adcc878 commit 8bc72b6

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

torchvision/ops/boxes.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -318,19 +318,25 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple
318318
raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.")
319319

320320
if fmt == "xyxy":
321-
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
322-
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
321+
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
322+
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
323323
elif fmt == "xywh":
324-
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
325-
rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:], boxes2[:, :2] + boxes2[:, 2:]) # [N,M,2]
324+
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
325+
rb = torch.min(
326+
boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :2] + boxes2[..., None, 2:]
327+
) # [N,M,2]
326328
else: # fmt == "cxcywh":
327-
lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2]
328-
rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2]
329+
lt = torch.max(
330+
boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :2] - boxes2[..., None, 2:] / 2
331+
) # [N,M,2]
332+
rb = torch.min(
333+
boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :2] + boxes2[..., None, 2:] / 2
334+
) # [N,M,2]
329335

330-
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
331-
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
336+
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
337+
inter = wh[..., 0] * wh[..., 1] # [N,M] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
332338

333-
union = area1[:, None] + area2 - inter
339+
union = area1[..., None] + area2[..., None, :] - inter
334340

335341
return inter, union
336342

0 commit comments

Comments
 (0)