Skip to content

Commit 3651f9e

Browse files
Remove dispatcher
1 parent 09ae6a0 commit 3651f9e

File tree

1 file changed

+57
-121
lines changed

1 file changed

+57
-121
lines changed

torchvision/ops/boxes.py

Lines changed: 57 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -275,130 +275,87 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
275275
Computes the area of a set of bounding boxes from a given format.
276276
277277
Args:
278-
boxes (Tensor[N, 4]): boxes for which the area will be computed. They
279-
are expected to be in (x1, y1, x2, y2) format with
280-
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
281-
fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy"
278+
boxes (Tensor[N, 4]): Tensor containing N boxes.
279+
format (str): Format of the input boxes.
280+
Default is "xyxy" to preserve backward compatibility.
281+
Supported formats are "xyxy", "xywh", and "cxcywh".
282282
283283
Returns:
284-
Tensor[N]: the area for each box
285-
"""
286-
if fmt == "xyxy":
287-
boxes = box_area_xyxy(boxes=boxes)
288-
elif fmt == "cxcywh":
289-
boxes = box_area_cxcywh(boxes=boxes)
290-
else:
291-
raise ValueError(f"Unsupported Box Area Calculation for given fmt {fmt}")
292-
293-
return boxes
294-
295-
296-
def box_area_xyxy(boxes: Tensor) -> Tensor:
297-
"""
298-
Computes the area of a set of bounding boxes, which are specified by their
299-
(x1, y1, x2, y2) coordinates.
300-
301-
Args:
302-
boxes (Tensor[N, 4]): boxes for which the area will be computed. They
303-
are expected to be in (x1, y1, x2, y2) format with
304-
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
305-
306-
Returns:
307-
Tensor[N]: the area for each box
284+
Tensor[N]: Tensor containing the area for each box.
308285
"""
309286
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
310-
_log_api_usage_once(box_area_xyxy)
287+
_log_api_usage_once(box_area)
288+
allowed_fmts = (
289+
"xyxy",
290+
"xywh",
291+
"cxcywh",
292+
)
293+
if fmt not in allowed_fmts:
294+
raise ValueError(f"Unsupported Bounding Box area for given fmt {fmt}")
311295
boxes = _upcast(boxes)
312-
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
313-
314-
315-
def box_area_cxcywh(boxes: Tensor) -> Tensor:
316-
"""
317-
Computes the area of a set of bounding boxes, which are specified by their
318-
(cx, cy, w, h) coordinates.
319-
320-
Args:
321-
boxes (Tensor[N, 4]): boxes for which the area will be computed. They
322-
are expected to be in (cx, cy, w, h) format with
323-
``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.
296+
if fmt == "xyxy":
297+
area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
298+
else:
299+
# For formats with width and height, area = width * height
300+
# Supported: cxcywh, xywh
301+
area = boxes[:, 2] * boxes[:, 3]
324302

325-
Returns:
326-
Tensor[N]: the area for each box
327-
"""
328-
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
329-
_log_api_usage_once(box_area_cxcywh)
330-
boxes = _upcast(boxes)
331-
return boxes[:, 2] * boxes[:, 3]
303+
return area
332304

333305

334306
def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
335307
"""
336308
Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.
337309
338310
Args:
339-
boxes1 (Tensor[N, 4]): first set of boxes
340-
boxes2 (Tensor[M, 4]): second set of boxes
341-
fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy"
342-
343-
344-
Returns:
345-
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
346-
"""
347-
if fmt == "xyxy":
348-
iou = box_iou_xyxy(boxes1=boxes1, boxes2=boxes2)
349-
elif fmt == "cxcywh":
350-
iou = box_iou_cxcywh(boxes1=boxes1, boxes2=boxes2)
351-
else:
352-
raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}")
353-
354-
return iou
355-
356-
357-
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
358-
# with slight modifications
359-
def _box_inter_union_xyxy(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
360-
area1 = box_area(boxes1, fmt="xyxy")
361-
area2 = box_area(boxes2, fmt="xyxy")
362-
363-
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
364-
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
311+
boxes1 (Tensor[N, 4]): first set of boxes.
312+
boxes2 (Tensor[M, 4]): second set of boxes.
313+
format (str): Format of the input boxes.
314+
Default is "xyxy" to preserve backward compatibility.
315+
Supported formats are "xyxy", "xywh", and "cxcywh".
365316
366-
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
367-
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
368-
369-
union = area1[:, None] + area2 - inter
370-
371-
return inter, union
372-
373-
374-
def box_iou_xyxy(boxes1: Tensor, boxes2: Tensor) -> Tensor:
375-
"""
376-
Return intersection-over-union (Jaccard index) between two sets of boxes.
377-
378-
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
379-
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
380-
381-
Args:
382-
boxes1 (Tensor[N, 4]): first set of boxes
383-
boxes2 (Tensor[M, 4]): second set of boxes
384317
385318
Returns:
386319
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
387320
"""
388321
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
389-
_log_api_usage_once(box_iou_xyxy)
390-
inter, union = _box_inter_union_xyxy(boxes1, boxes2)
322+
_log_api_usage_once(box_iou)
323+
allowed_fmts = (
324+
"xyxy",
325+
"xywh",
326+
"cxcywh",
327+
)
328+
if fmt not in allowed_fmts:
329+
raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.")
330+
inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt)
391331
iou = inter / union
392332
return iou
393333

394334

395-
def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
396-
area1 = box_area(boxes1, fmt="cxcywh")
397-
area2 = box_area(boxes2, fmt="cxcywh")
335+
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
336+
# with slight modifications
337+
def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]:
338+
area1 = box_area(boxes1, fmt=fmt)
339+
area2 = box_area(boxes2, fmt=fmt)
398340

399-
lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2]
400-
rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2]
341+
allowed_fmts = (
342+
"xyxy",
343+
"xywh",
344+
"cxcywh",
345+
)
346+
if fmt not in allowed_fmts:
347+
raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.")
401348

349+
if fmt == "xyxy":
350+
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
351+
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
352+
elif fmt == "xywh":
353+
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
354+
rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2]
355+
else: # fmt == "cxcywh":
356+
lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2]
357+
rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2]
358+
402359
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
403360
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
404361

@@ -407,27 +364,6 @@ def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Ten
407364
return inter, union
408365

409366

410-
def box_iou_cxcywh(boxes1: Tensor, boxes2: Tensor) -> Tensor:
411-
"""
412-
Return intersection-over-union (Jaccard index) between two sets of boxes.
413-
414-
Both sets of boxes are expected to be in ``(cx, cy, w, h)`` format with
415-
``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.
416-
417-
Args:
418-
boxes1 (Tensor[N, 4]): first set of boxes
419-
boxes2 (Tensor[M, 4]): second set of boxes
420-
421-
Returns:
422-
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
423-
"""
424-
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
425-
_log_api_usage_once(box_iou_cxcywh)
426-
inter, union = _box_inter_union_cxcywh(boxes1, boxes2)
427-
iou = inter / union
428-
return iou
429-
430-
431367
# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
432368
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
433369
"""
@@ -447,7 +383,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
447383
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
448384
_log_api_usage_once(generalized_box_iou)
449385

450-
inter, union = _box_inter_union_xyxy(boxes1, boxes2)
386+
inter, union = _box_inter_union(boxes1, boxes2)
451387
iou = inter / union
452388

453389
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])

0 commit comments

Comments
 (0)