@@ -303,35 +303,6 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
303303 return area
304304
305305
306- def box_iou (boxes1 : Tensor , boxes2 : Tensor , fmt : str = "xyxy" ) -> Tensor :
307- """
308- Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.
309-
310- Args:
311- boxes1 (Tensor[..., N, 4]): first set of boxes
312- boxes2 (Tensor[..., M, 4]): second set of boxes
313- format (str): Format of the input boxes.
314- Default is "xyxy" to preserve backward compatibility.
315- Supported formats are "xyxy", "xywh", and "cxcywh".
316-
317- Returns:
318- Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
319- in boxes1 and boxes2
320- """
321- if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
322- _log_api_usage_once (box_iou )
323- allowed_fmts = (
324- "xyxy" ,
325- "xywh" ,
326- "cxcywh" ,
327- )
328- if fmt not in allowed_fmts :
329- raise ValueError (f"Unsupported Box IoU Calculation for given fmt { format } ." )
330- inter , union = _box_inter_union (boxes1 , boxes2 , fmt = fmt )
331- iou = inter / union
332- return iou
333-
334-
335306# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
336307# with slight modifications
337308def _box_inter_union (boxes1 : Tensor , boxes2 : Tensor , fmt : str = "xyxy" ) -> tuple [Tensor , Tensor ]:
@@ -364,6 +335,35 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple
364335 return inter , union
365336
366337
338+ def box_iou (boxes1 : Tensor , boxes2 : Tensor , fmt : str = "xyxy" ) -> Tensor :
339+ """
340+ Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.
341+
342+ Args:
343+ boxes1 (Tensor[..., N, 4]): first set of boxes
344+ boxes2 (Tensor[..., M, 4]): second set of boxes
345+ format (str): Format of the input boxes.
346+ Default is "xyxy" to preserve backward compatibility.
347+ Supported formats are "xyxy", "xywh", and "cxcywh".
348+
349+ Returns:
350+ Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
351+ in boxes1 and boxes2
352+ """
353+ if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
354+ _log_api_usage_once (box_iou )
355+ allowed_fmts = (
356+ "xyxy" ,
357+ "xywh" ,
358+ "cxcywh" ,
359+ )
360+ if fmt not in allowed_fmts :
361+ raise ValueError (f"Unsupported Box IoU Calculation for given fmt { format } ." )
362+ inter , union = _box_inter_union (boxes1 , boxes2 , fmt = fmt )
363+ iou = inter / union
364+ return iou
365+
366+
367367# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
368368def generalized_box_iou (boxes1 : Tensor , boxes2 : Tensor ) -> Tensor :
369369 """
0 commit comments