Skip to content

Commit c509b11

Browse files
Keep batch dimension (2/2)
1 parent 8bc72b6 commit c509b11

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchvision/ops/boxes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,11 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
294294
raise ValueError(f"Unsupported Bounding Box area for given fmt {fmt}")
295295
boxes = _upcast(boxes)
296296
if fmt == "xyxy":
297-
area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
297+
area = (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
298298
else:
299299
# For formats with width and height, area = width * height
300300
# Supported: cxcywh, xywh
301-
area = boxes[:, 2] * boxes[:, 3]
301+
area = boxes[..., 2] * boxes[..., 3]
302302

303303
return area
304304

@@ -323,14 +323,14 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple
323323
elif fmt == "xywh":
324324
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
325325
rb = torch.min(
326-
boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :2] + boxes2[..., None, 2:]
326+
boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:]
327327
) # [N,M,2]
328328
else: # fmt == "cxcywh":
329329
lt = torch.max(
330-
boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :2] - boxes2[..., None, 2:] / 2
330+
boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2
331331
) # [N,M,2]
332332
rb = torch.min(
333-
boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :2] + boxes2[..., None, 2:] / 2
333+
boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2
334334
) # [N,M,2]
335335

336336
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]

0 commit comments

Comments
 (0)