Skip to content

Commit adcc878

Browse files
Re-order with original file structure
1 parent 7b6a3ea commit adcc878

File tree

1 file changed

+29
-29
lines changed

1 file changed

+29
-29
lines changed

torchvision/ops/boxes.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -303,35 +303,6 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
303303
return area
304304

305305

306-
def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
307-
"""
308-
Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.
309-
310-
Args:
311-
boxes1 (Tensor[..., N, 4]): first set of boxes
312-
boxes2 (Tensor[..., M, 4]): second set of boxes
313-
format (str): Format of the input boxes.
314-
Default is "xyxy" to preserve backward compatibility.
315-
Supported formats are "xyxy", "xywh", and "cxcywh".
316-
317-
Returns:
318-
Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
319-
in boxes1 and boxes2
320-
"""
321-
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
322-
_log_api_usage_once(box_iou)
323-
allowed_fmts = (
324-
"xyxy",
325-
"xywh",
326-
"cxcywh",
327-
)
328-
if fmt not in allowed_fmts:
329-
raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.")
330-
inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt)
331-
iou = inter / union
332-
return iou
333-
334-
335306
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
336307
# with slight modifications
337308
def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]:
@@ -364,6 +335,35 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple
364335
return inter, union
365336

366337

338+
def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
339+
"""
340+
Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.
341+
342+
Args:
343+
boxes1 (Tensor[..., N, 4]): first set of boxes
344+
boxes2 (Tensor[..., M, 4]): second set of boxes
345+
format (str): Format of the input boxes.
346+
Default is "xyxy" to preserve backward compatibility.
347+
Supported formats are "xyxy", "xywh", and "cxcywh".
348+
349+
Returns:
350+
Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
351+
in boxes1 and boxes2
352+
"""
353+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
354+
_log_api_usage_once(box_iou)
355+
allowed_fmts = (
356+
"xyxy",
357+
"xywh",
358+
"cxcywh",
359+
)
360+
if fmt not in allowed_fmts:
361+
raise ValueError(f"Unsupported Box IoU Calculation for given fmt {format}.")
362+
inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt)
363+
iou = inter / union
364+
return iou
365+
366+
367367
# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
368368
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
369369
"""

0 commit comments

Comments
 (0)