@@ -130,7 +130,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
130130 the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.
131131
132132 Args:
133- boxes (Tensor[N , 4]): boxes in ``(x1, y1, x2, y2)`` format
133+ boxes (Tensor[... , 4]): boxes in ``(x1, y1, x2, y2)`` format
134134 with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
135135 min_size (float): minimum size
136136
@@ -140,7 +140,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
140140 """
141141 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
142142 _log_api_usage_once (remove_small_boxes )
143- ws , hs = boxes [: , 2 ] - boxes [: , 0 ], boxes [: , 3 ] - boxes [: , 1 ]
143+ ws , hs = boxes [... , 2 ] - boxes [... , 0 ], boxes [... , 3 ] - boxes [... , 1 ]
144144 keep = (ws >= min_size ) & (hs >= min_size )
145145 keep = torch .where (keep )[0 ]
146146 return keep
@@ -155,12 +155,12 @@ def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor:
155155 the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.
156156
157157 Args:
158- boxes (Tensor[N , 4]): boxes in ``(x1, y1, x2, y2)`` format
158+ boxes (Tensor[... , 4]): boxes in ``(x1, y1, x2, y2)`` format
159159 with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
160160 size (Tuple[height, width]): size of the image
161161
162162 Returns:
163- Tensor[N , 4]: clipped boxes
163+ Tensor[... , 4]: clipped boxes
164164 """
165165 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
166166 _log_api_usage_once (clip_boxes_to_image )
@@ -276,7 +276,7 @@ def box_area(boxes: Tensor) -> Tensor:
276276 (x1, y1, x2, y2) coordinates.
277277
278278 Args:
279- boxes (Tensor[N , 4]): boxes for which the area will be computed. They
279+ boxes (Tensor[... , 4]): boxes for which the area will be computed. They
280280 are expected to be in (x1, y1, x2, y2) format with
281281 ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
282282
@@ -286,7 +286,7 @@ def box_area(boxes: Tensor) -> Tensor:
286286 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
287287 _log_api_usage_once (box_area )
288288 boxes = _upcast (boxes )
289- return (boxes [: , 2 ] - boxes [: , 0 ]) * (boxes [: , 3 ] - boxes [: , 1 ])
289+ return (boxes [... , 2 ] - boxes [... , 0 ]) * (boxes [... , 3 ] - boxes [... , 1 ])
290290
291291
292292# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
@@ -295,13 +295,13 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
295295 area1 = box_area (boxes1 )
296296 area2 = box_area (boxes2 )
297297
298- lt = torch .max (boxes1 [: , None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
299- rb = torch .min (boxes1 [: , None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
298+ lt = torch .max (boxes1 [... , None , :2 ], boxes2 [..., None , :, :2 ]) # [..., N,M,2]
299+ rb = torch .min (boxes1 [... , None , 2 :], boxes2 [..., None , :, 2 :]) # [..., N,M,2]
300300
301301 wh = _upcast (rb - lt ).clamp (min = 0 ) # [..., N,M,2]
302- inter = wh [:, :, 0 ] * wh [:, : , 1 ] # [N,M]
302+ inter = wh [..., 0 ] * wh [... , 1 ] # [..., N,M]
303303
304- union = area1 [: , None ] + area2 - inter
304+ union = area1 [... , None ] + area2 [..., None , :] - inter
305305
306306 return inter , union
307307
@@ -314,11 +314,12 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
314314 ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
315315
316316 Args:
317- boxes1 (Tensor[N, 4]): first set of boxes
318- boxes2 (Tensor[M, 4]): second set of boxes
317+ boxes1 (Tensor[..., N, 4]): first set of boxes
318+ boxes2 (Tensor[..., M, 4]): second set of boxes
319319
320320 Returns:
321- Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
321+ Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
322+ in boxes1 and boxes2
322323 """
323324 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
324325 _log_api_usage_once (box_iou )
@@ -336,11 +337,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
336337 ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
337338
338339 Args:
339- boxes1 (Tensor[N, 4]): first set of boxes
340- boxes2 (Tensor[M, 4]): second set of boxes
340+ boxes1 (Tensor[..., N, 4]): first set of boxes
341+ boxes2 (Tensor[..., M, 4]): second set of boxes
341342
342343 Returns:
343- Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
344+ Tensor[..., N, M]: the NxM matrix containing the pairwise generalized IoU values
344345 for every element in boxes1 and boxes2
345346 """
346347 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
@@ -349,11 +350,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
349350 inter , union = _box_inter_union (boxes1 , boxes2 )
350351 iou = inter / union
351352
352- lti = torch .min (boxes1 [: , None , :2 ], boxes2 [:, :2 ])
353- rbi = torch .max (boxes1 [: , None , 2 :], boxes2 [:, 2 :])
353+ lti = torch .min (boxes1 [... , None , :2 ], boxes2 [..., None , :, :2 ])
354+ rbi = torch .max (boxes1 [... , None , 2 :], boxes2 [..., None , :, 2 :])
354355
355356 whi = _upcast (rbi - lti ).clamp (min = 0 ) # [..., N,M,2]
356- areai = whi [:, :, 0 ] * whi [:, : , 1 ]
357+ areai = whi [..., 0 ] * whi [... , 1 ]
357358
358359 return iou - (areai - union ) / areai
359360
@@ -364,11 +365,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
364365 Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
365366 ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
366367 Args:
367- boxes1 (Tensor[N, 4]): first set of boxes
368- boxes2 (Tensor[M, 4]): second set of boxes
368+ boxes1 (Tensor[..., N, 4]): first set of boxes
369+ boxes2 (Tensor[..., M, 4]): second set of boxes
369370 eps (float, optional): small number to prevent division by zero. Default: 1e-7
370371 Returns:
371- Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
372+ Tensor[..., N, M]: the NxM matrix containing the pairwise complete IoU values
372373 for every element in boxes1 and boxes2
373374 """
374375 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
@@ -379,11 +380,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
379380
380381 diou , iou = _box_diou_iou (boxes1 , boxes2 , eps )
381382
382- w_pred = boxes1 [: , None , 2 ] - boxes1 [: , None , 0 ]
383- h_pred = boxes1 [: , None , 3 ] - boxes1 [: , None , 1 ]
383+ w_pred = boxes1 [... , None , 2 ] - boxes1 [... , None , 0 ]
384+ h_pred = boxes1 [... , None , 3 ] - boxes1 [... , None , 1 ]
384385
385- w_gt = boxes2 [:, 2 ] - boxes2 [:, 0 ]
386- h_gt = boxes2 [:, 3 ] - boxes2 [:, 1 ]
386+ w_gt = boxes2 [..., None , :, 2 ] - boxes2 [..., None , :, 0 ]
387+ h_gt = boxes2 [..., None , :, 3 ] - boxes2 [..., None , :, 1 ]
387388
388389 v = (4 / (torch .pi ** 2 )) * torch .pow (torch .atan (w_pred / h_pred ) - torch .atan (w_gt / h_gt ), 2 )
389390 with torch .no_grad ():
@@ -399,12 +400,12 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
399400 ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
400401
401402 Args:
402- boxes1 (Tensor[N, 4]): first set of boxes
403- boxes2 (Tensor[M, 4]): second set of boxes
403+ boxes1 (Tensor[..., N, 4]): first set of boxes
404+ boxes2 (Tensor[..., M, 4]): second set of boxes
404405 eps (float, optional): small number to prevent division by zero. Default: 1e-7
405406
406407 Returns:
407- Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
408+ Tensor[..., N, M]: the NxM matrix containing the pairwise distance IoU values
408409 for every element in boxes1 and boxes2
409410 """
410411 if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
@@ -419,17 +420,19 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> tuple[Tensor, Tensor]:
    """Compute the pairwise distance-IoU and plain IoU between two sets of boxes.

    Both sets of boxes are indexed as ``(x1, y1, x2, y2)`` along the last
    dimension. Any leading dimensions are treated as batch dimensions.
    NOTE(review): the leading dimensions of ``boxes1`` and ``boxes2`` presumably
    have to match or be broadcastable against each other -- confirm with callers.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes
        boxes2 (Tensor[..., M, 4]): second set of boxes
        eps (float, optional): small number added to the squared diagonal of the
            enclosing box to prevent division by zero. Default: 1e-7

    Returns:
        tuple[Tensor, Tensor]: ``(diou, iou)``; both are pairwise matrices over
        the N boxes of ``boxes1`` and the M boxes of ``boxes2``.
    """
    iou = box_iou(boxes1, boxes2)

    # Insert a singleton axis so the two sets broadcast pairwise:
    # boxes1 -> [..., N, 1, 2] and boxes2 -> [..., 1, M, 2].
    lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
    rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])

    # Width/height of the smallest box enclosing each pair.
    whi = _upcast(rbi - lti).clamp(min=0)  # [..., N, M, 2]
    diagonal_distance_squared = (whi[..., 0] ** 2) + (whi[..., 1] ** 2) + eps

    # Centers of boxes.
    x_p = (boxes1[..., 0] + boxes1[..., 2]) / 2
    y_p = (boxes1[..., 1] + boxes1[..., 3]) / 2
    x_g = (boxes2[..., 0] + boxes2[..., 2]) / 2
    y_g = (boxes2[..., 1] + boxes2[..., 3]) / 2

    # The distance between boxes' centers squared, computed pairwise by
    # broadcasting [..., N, 1] against [..., 1, M].
    centers_distance_squared = (_upcast(x_p[..., None] - x_g[..., None, :]) ** 2) + (
        _upcast(y_p[..., None] - y_g[..., None, :]) ** 2
    )

    # The distance IoU is the IoU penalized by a normalized
    # distance between boxes' centers squared.
    return iou - (centers_distance_squared / diagonal_distance_squared), iou
0 commit comments