@@ -275,130 +275,87 @@ def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
275275 Computes the area of a set of bounding boxes from a given format.
276276
277277 Args:
278- boxes (Tensor[N, 4]): boxes for which the area will be computed. They
279- are expected to be in (x1, y1, x2, y2) format with
280- ``0 <= x1 < x2`` and ``0 <= y1 < y2`` .
281- fmt (str): Format of given boxes. Supported formats are [' xyxy', 'cxcywh']. Default: "xyxy"
278+ boxes (Tensor[N, 4]): Tensor containing N boxes.
279+ format (str): Format of the input boxes.
280+ Default is "xyxy" to preserve backward compatibility .
281+ Supported formats are " xyxy", "xywh", and "cxcywh".
282282
283283 Returns:
284- Tensor[N]: the area for each box
285- """
286- if fmt == "xyxy" :
287- boxes = box_area_xyxy (boxes = boxes )
288- elif fmt == "cxcywh" :
289- boxes = box_area_cxcywh (boxes = boxes )
290- else :
291- raise ValueError (f"Unsupported Box Area Calculation for given fmt { fmt } " )
292-
293- return boxes
294-
295-
296- def box_area_xyxy (boxes : Tensor ) -> Tensor :
297- """
298- Computes the area of a set of bounding boxes, which are specified by their
299- (x1, y1, x2, y2) coordinates.
300-
301- Args:
302- boxes (Tensor[N, 4]): boxes for which the area will be computed. They
303- are expected to be in (x1, y1, x2, y2) format with
304- ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
305-
306- Returns:
307- Tensor[N]: the area for each box
284+ Tensor[N]: Tensor containing the area for each box.
308285 """
309286 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
310- _log_api_usage_once (box_area_xyxy )
287+ _log_api_usage_once (box_area )
288+ allowed_fmts = (
289+ "xyxy" ,
290+ "xywh" ,
291+ "cxcywh" ,
292+ )
293+ if fmt not in allowed_fmts :
294+ raise ValueError (f"Unsupported Bounding Box area for given fmt { fmt } " )
311295 boxes = _upcast (boxes )
312- return (boxes [:, 2 ] - boxes [:, 0 ]) * (boxes [:, 3 ] - boxes [:, 1 ])
313-
314-
315- def box_area_cxcywh (boxes : Tensor ) -> Tensor :
316- """
317- Computes the area of a set of bounding boxes, which are specified by their
318- (cx, cy, w, h) coordinates.
319-
320- Args:
321- boxes (Tensor[N, 4]): boxes for which the area will be computed. They
322- are expected to be in (cx, cy, w, h) format with
323- ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.
296+ if fmt == "xyxy" :
297+ area = (boxes [:, 2 ] - boxes [:, 0 ]) * (boxes [:, 3 ] - boxes [:, 1 ])
298+ else :
299+ # For formats with width and height, area = width * height
300+ # Supported: cxcywh, xywh
301+ area = boxes [:, 2 ] * boxes [:, 3 ]
324302
325- Returns:
326- Tensor[N]: the area for each box
327- """
328- if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
329- _log_api_usage_once (box_area_cxcywh )
330- boxes = _upcast (boxes )
331- return boxes [:, 2 ] * boxes [:, 3 ]
303+ return area
332304
333305
334306def box_iou (boxes1 : Tensor , boxes2 : Tensor , fmt : str = "xyxy" ) -> Tensor :
335307 """
336308 Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.
337309
338310 Args:
339- boxes1 (Tensor[N, 4]): first set of boxes
340- boxes2 (Tensor[M, 4]): second set of boxes
341- fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy"
342-
343-
344- Returns:
345- Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
346- """
347- if fmt == "xyxy" :
348- iou = box_iou_xyxy (boxes1 = boxes1 , boxes2 = boxes2 )
349- elif fmt == "cxcywh" :
350- iou = box_iou_cxcywh (boxes1 = boxes1 , boxes2 = boxes2 )
351- else :
352- raise ValueError (f"Unsupported Box IoU Calculation for given fmt { fmt } " )
353-
354- return iou
355-
356-
357- # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
358- # with slight modifications
359- def _box_inter_union_xyxy (boxes1 : Tensor , boxes2 : Tensor ) -> tuple [Tensor , Tensor ]:
360- area1 = box_area (boxes1 , fmt = "xyxy" )
361- area2 = box_area (boxes2 , fmt = "xyxy" )
362-
363- lt = torch .max (boxes1 [:, None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
364- rb = torch .min (boxes1 [:, None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
311+ boxes1 (Tensor[N, 4]): first set of boxes.
312+ boxes2 (Tensor[M, 4]): second set of boxes.
313+ format (str): Format of the input boxes.
314+ Default is "xyxy" to preserve backward compatibility.
315+ Supported formats are "xyxy", "xywh", and "cxcywh".
365316
366- wh = _upcast (rb - lt ).clamp (min = 0 ) # [N,M,2]
367- inter = wh [:, :, 0 ] * wh [:, :, 1 ] # [N,M]
368-
369- union = area1 [:, None ] + area2 - inter
370-
371- return inter , union
372-
373-
374- def box_iou_xyxy (boxes1 : Tensor , boxes2 : Tensor ) -> Tensor :
375- """
376- Return intersection-over-union (Jaccard index) between two sets of boxes.
377-
378- Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
379- ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
380-
381- Args:
382- boxes1 (Tensor[N, 4]): first set of boxes
383- boxes2 (Tensor[M, 4]): second set of boxes
384317
385318 Returns:
386319 Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
387320 """
388321 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
389- _log_api_usage_once (box_iou_xyxy )
390- inter , union = _box_inter_union_xyxy (boxes1 , boxes2 )
322+ _log_api_usage_once (box_iou )
323+ allowed_fmts = (
324+ "xyxy" ,
325+ "xywh" ,
326+ "cxcywh" ,
327+ )
328+ if fmt not in allowed_fmts :
329+ raise ValueError (f"Unsupported Box IoU Calculation for given fmt { format } ." )
330+ inter , union = _box_inter_union (boxes1 , boxes2 , fmt = fmt )
391331 iou = inter / union
392332 return iou
393333
394334
395- def _box_inter_union_cxcywh (boxes1 : Tensor , boxes2 : Tensor ) -> tuple [Tensor , Tensor ]:
396- area1 = box_area (boxes1 , fmt = "cxcywh" )
397- area2 = box_area (boxes2 , fmt = "cxcywh" )
335+ # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
336+ # with slight modifications
337+ def _box_inter_union (boxes1 : Tensor , boxes2 : Tensor , fmt : str = "xyxy" ) -> tuple [Tensor , Tensor ]:
338+ area1 = box_area (boxes1 , fmt = fmt )
339+ area2 = box_area (boxes2 , fmt = fmt )
398340
399- lt = torch .max (boxes1 [:, None , :2 ] - boxes1 [:, None , 2 :] / 2 , boxes2 [:, :2 ] - boxes2 [:, 2 :] / 2 ) # [N,M,2]
400- rb = torch .min (boxes1 [:, None , :2 ] + boxes1 [:, None , 2 :] / 2 , boxes2 [:, :2 ] + boxes2 [:, 2 :] / 2 ) # [N,M,2]
341+ allowed_fmts = (
342+ "xyxy" ,
343+ "xywh" ,
344+ "cxcywh" ,
345+ )
346+ if fmt not in allowed_fmts :
347+ raise ValueError (f"Unsupported Box IoU Calculation for given fmt { format } ." )
401348
349+ if fmt == "xyxy" :
350+ lt = torch .max (boxes1 [:, None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
351+ rb = torch .min (boxes1 [:, None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
352+ elif fmt == "xywh" :
353+ lt = torch .max (boxes1 [:, None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
354+ rb = torch .min (boxes1 [:, None , :2 ] + boxes1 [:, None , 2 :] / 2 , boxes2 [:, :2 ] + boxes2 [:, 2 :] / 2 ) # [N,M,2]
355+ else : # fmt == "cxcywh":
356+ lt = torch .max (boxes1 [:, None , :2 ] - boxes1 [:, None , 2 :] / 2 , boxes2 [:, :2 ] - boxes2 [:, 2 :] / 2 ) # [N,M,2]
357+ rb = torch .min (boxes1 [:, None , :2 ] + boxes1 [:, None , 2 :] / 2 , boxes2 [:, :2 ] + boxes2 [:, 2 :] / 2 ) # [N,M,2]
358+
402359 wh = _upcast (rb - lt ).clamp (min = 0 ) # [N,M,2]
403360 inter = wh [:, :, 0 ] * wh [:, :, 1 ] # [N,M]
404361
@@ -407,27 +364,6 @@ def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Ten
407364 return inter , union
408365
409366
410- def box_iou_cxcywh (boxes1 : Tensor , boxes2 : Tensor ) -> Tensor :
411- """
412- Return intersection-over-union (Jaccard index) between two sets of boxes.
413-
414- Both sets of boxes are expected to be in ``(cx, cy, w, h)`` format with
415- ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.
416-
417- Args:
418- boxes1 (Tensor[N, 4]): first set of boxes
419- boxes2 (Tensor[M, 4]): second set of boxes
420-
421- Returns:
422- Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
423- """
424- if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
425- _log_api_usage_once (box_iou_cxcywh )
426- inter , union = _box_inter_union_cxcywh (boxes1 , boxes2 )
427- iou = inter / union
428- return iou
429-
430-
431367# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
432368def generalized_box_iou (boxes1 : Tensor , boxes2 : Tensor ) -> Tensor :
433369 """
@@ -447,7 +383,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
447383 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
448384 _log_api_usage_once (generalized_box_iou )
449385
450- inter , union = _box_inter_union_xyxy (boxes1 , boxes2 )
386+ inter , union = _box_inter_union (boxes1 , boxes2 )
451387 iou = inter / union
452388
453389 lti = torch .min (boxes1 [:, None , :2 ], boxes2 [:, :2 ])
0 commit comments