@@ -130,7 +130,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
130130        the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead. 
131131
132132    Args: 
133-         boxes (Tensor[N , 4]): boxes in ``(x1, y1, x2, y2)`` format 
133+         boxes (Tensor[... , 4]): boxes in ``(x1, y1, x2, y2)`` format 
134134            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
135135        min_size (float): minimum size 
136136
@@ -140,7 +140,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
140140    """ 
141141    if  not  torch .jit .is_scripting () and  not  torch .jit .is_tracing ():
142142        _log_api_usage_once (remove_small_boxes )
143-     ws , hs  =  boxes [: , 2 ] -  boxes [: , 0 ], boxes [: , 3 ] -  boxes [: , 1 ]
143+     ws , hs  =  boxes [... , 2 ] -  boxes [... , 0 ], boxes [... , 3 ] -  boxes [... , 1 ]
144144    keep  =  (ws  >=  min_size ) &  (hs  >=  min_size )
145145    keep  =  torch .where (keep )[0 ]
146146    return  keep 
@@ -155,12 +155,12 @@ def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor:
155155        the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead. 
156156
157157    Args: 
158-         boxes (Tensor[N , 4]): boxes in ``(x1, y1, x2, y2)`` format 
158+         boxes (Tensor[... , 4]): boxes in ``(x1, y1, x2, y2)`` format 
159159            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
160160        size (Tuple[height, width]): size of the image 
161161
162162    Returns: 
163-         Tensor[N , 4]: clipped boxes 
163+         Tensor[... , 4]: clipped boxes 
164164    """ 
165165    if  not  torch .jit .is_scripting () and  not  torch .jit .is_tracing ():
166166        _log_api_usage_once (clip_boxes_to_image )
@@ -276,7 +276,7 @@ def box_area(boxes: Tensor) -> Tensor:
276276    (x1, y1, x2, y2) coordinates. 
277277
278278    Args: 
279-         boxes (Tensor[N , 4]): boxes for which the area will be computed. They 
279+         boxes (Tensor[... , 4]): boxes for which the area will be computed. They 
280280            are expected to be in (x1, y1, x2, y2) format with 
281281            ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
282282
@@ -286,7 +286,7 @@ def box_area(boxes: Tensor) -> Tensor:
286286    if  not  torch .jit .is_scripting () and  not  torch .jit .is_tracing ():
287287        _log_api_usage_once (box_area )
288288    boxes  =  _upcast (boxes )
289-     return  (boxes [: , 2 ] -  boxes [: , 0 ]) *  (boxes [: , 3 ] -  boxes [: , 1 ])
289+     return  (boxes [... , 2 ] -  boxes [... , 0 ]) *  (boxes [... , 3 ] -  boxes [... , 1 ])
290290
291291
292292# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py 
@@ -295,13 +295,13 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
295295    area1  =  box_area (boxes1 )
296296    area2  =  box_area (boxes2 )
297297
298-     lt  =  torch .max (boxes1 [: , None , :2 ], boxes2 [:, :2 ])  # [N,M,2] 
299-     rb  =  torch .min (boxes1 [: , None , 2 :], boxes2 [:, 2 :])  # [N,M,2] 
298+     lt  =  torch .max (boxes1 [... , None , :2 ], boxes2 [...,  None ,  :, :2 ])  # [..., N,M,2] 
299+     rb  =  torch .min (boxes1 [... , None , 2 :], boxes2 [...,  None ,  :, 2 :])  # [..., N,M,2] 
300300
301301    wh  =  _upcast (rb  -  lt ).clamp (min = 0 )  # [N,M,2] 
302-     inter  =  wh [:, :,  0 ] *  wh [:, : , 1 ]  # [N,M] 
302+     inter  =  wh [...,  0 ] *  wh [... , 1 ]  # [N,M] 
303303
304-     union  =  area1 [: , None ] +  area2  -  inter 
304+     union  =  area1 [... , None ] +  area2 [...,  None , :]  -  inter 
305305
306306    return  inter , union 
307307
@@ -314,11 +314,12 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
314314    ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
315315
316316    Args: 
317-         boxes1 (Tensor[N, 4]): first set of boxes 
318-         boxes2 (Tensor[M, 4]): second set of boxes 
317+         boxes1 (Tensor[...,  N, 4]): first set of boxes 
318+         boxes2 (Tensor[...,  M, 4]): second set of boxes 
319319
320320    Returns: 
321-         Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 
321+         Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element 
322+         in boxes1 and boxes2 
322323    """ 
323324    if  not  torch .jit .is_scripting () and  not  torch .jit .is_tracing ():
324325        _log_api_usage_once (box_iou )
@@ -336,11 +337,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
336337    ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
337338
338339    Args: 
339-         boxes1 (Tensor[N, 4]): first set of boxes 
340-         boxes2 (Tensor[M, 4]): second set of boxes 
340+         boxes1 (Tensor[...,  N, 4]): first set of boxes 
341+         boxes2 (Tensor[...,  M, 4]): second set of boxes 
341342
342343    Returns: 
343-         Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values 
344+         Tensor[...,  N, M]: the NxM matrix containing the pairwise generalized IoU values 
344345        for every element in boxes1 and boxes2 
345346    """ 
346347    if  not  torch .jit .is_scripting () and  not  torch .jit .is_tracing ():
@@ -349,11 +350,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
349350    inter , union  =  _box_inter_union (boxes1 , boxes2 )
350351    iou  =  inter  /  union 
351352
352-     lti  =  torch .min (boxes1 [: , None , :2 ], boxes2 [:, :2 ])
353-     rbi  =  torch .max (boxes1 [: , None , 2 :], boxes2 [:, 2 :])
353+     lti  =  torch .min (boxes1 [... , None , :2 ], boxes2 [...,  None ,  :, :2 ])
354+     rbi  =  torch .max (boxes1 [... , None , 2 :], boxes2 [...,  None ,  :, 2 :])
354355
355356    whi  =  _upcast (rbi  -  lti ).clamp (min = 0 )  # [N,M,2] 
356-     areai  =  whi [:, :,  0 ] *  whi [:, : , 1 ]
357+     areai  =  whi [...,  0 ] *  whi [... , 1 ]
357358
358359    return  iou  -  (areai  -  union ) /  areai 
359360
@@ -364,11 +365,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
364365    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with 
365366    ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
366367    Args: 
367-         boxes1 (Tensor[N, 4]): first set of boxes 
368-         boxes2 (Tensor[M, 4]): second set of boxes 
368+         boxes1 (Tensor[...,  N, 4]): first set of boxes 
369+         boxes2 (Tensor[...,  M, 4]): second set of boxes 
369370        eps (float, optional): small number to prevent division by zero. Default: 1e-7 
370371    Returns: 
371-         Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values 
372+         Tensor[...,  N, M]: the NxM matrix containing the pairwise complete IoU values 
372373        for every element in boxes1 and boxes2 
373374    """ 
374375    if  not  torch .jit .is_scripting () and  not  torch .jit .is_tracing ():
@@ -379,11 +380,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
379380
380381    diou , iou  =  _box_diou_iou (boxes1 , boxes2 , eps )
381382
382-     w_pred  =  boxes1 [: , None , 2 ] -  boxes1 [: , None , 0 ]
383-     h_pred  =  boxes1 [: , None , 3 ] -  boxes1 [: , None , 1 ]
383+     w_pred  =  boxes1 [... , None , 2 ] -  boxes1 [... , None , 0 ]
384+     h_pred  =  boxes1 [... , None , 3 ] -  boxes1 [... , None , 1 ]
384385
385-     w_gt  =  boxes2 [:, 2 ] -  boxes2 [:, 0 ]
386-     h_gt  =  boxes2 [:, 3 ] -  boxes2 [:, 1 ]
386+     w_gt  =  boxes2 [...,  None ,  :, 2 ] -  boxes2 [...,  None ,  :, 0 ]
387+     h_gt  =  boxes2 [...,  None ,  :, 3 ] -  boxes2 [...,  None ,  :, 1 ]
387388
388389    v  =  (4  /  (torch .pi ** 2 )) *  torch .pow (torch .atan (w_pred  /  h_pred ) -  torch .atan (w_gt  /  h_gt ), 2 )
389390    with  torch .no_grad ():
@@ -399,12 +400,12 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
399400    ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
400401
401402    Args: 
402-         boxes1 (Tensor[N, 4]): first set of boxes 
403-         boxes2 (Tensor[M, 4]): second set of boxes 
403+         boxes1 (Tensor[...,  N, 4]): first set of boxes 
404+         boxes2 (Tensor[...,  M, 4]): second set of boxes 
404405        eps (float, optional): small number to prevent division by zero. Default: 1e-7 
405406
406407    Returns: 
407-         Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values 
408+         Tensor[...,  N, M]: the NxM matrix containing the pairwise distance IoU values 
408409        for every element in boxes1 and boxes2 
409410    """ 
410411    if  not  torch .jit .is_scripting () and  not  torch .jit .is_tracing ():
@@ -419,17 +420,19 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
419420def  _box_diou_iou (boxes1 : Tensor , boxes2 : Tensor , eps : float  =  1e-7 ) ->  tuple [Tensor , Tensor ]:
420421
421422    iou  =  box_iou (boxes1 , boxes2 )
422-     lti  =  torch .min (boxes1 [: , None , :2 ], boxes2 [:, :2 ])
423-     rbi  =  torch .max (boxes1 [: , None , 2 :], boxes2 [:, 2 :])
423+     lti  =  torch .min (boxes1 [... , None , :2 ], boxes2 [...,  None ,  :, :2 ])
424+     rbi  =  torch .max (boxes1 [... , None , 2 :], boxes2 [...,  None ,  :, 2 :])
424425    whi  =  _upcast (rbi  -  lti ).clamp (min = 0 )  # [N,M,2] 
425-     diagonal_distance_squared  =  (whi [:, :,  0 ] **  2 ) +  (whi [:, : , 1 ] **  2 ) +  eps 
426+     diagonal_distance_squared  =  (whi [...,  0 ] **  2 ) +  (whi [... , 1 ] **  2 ) +  eps 
426427    # centers of boxes 
427-     x_p  =  (boxes1 [: , 0 ] +  boxes1 [: , 2 ]) /  2 
428-     y_p  =  (boxes1 [: , 1 ] +  boxes1 [: , 3 ]) /  2 
429-     x_g  =  (boxes2 [: , 0 ] +  boxes2 [: , 2 ]) /  2 
430-     y_g  =  (boxes2 [: , 1 ] +  boxes2 [: , 3 ]) /  2 
428+     x_p  =  (boxes1 [... , 0 ] +  boxes1 [... , 2 ]) /  2 
429+     y_p  =  (boxes1 [... , 1 ] +  boxes1 [... , 3 ]) /  2 
430+     x_g  =  (boxes2 [... , 0 ] +  boxes2 [... , 2 ]) /  2 
431+     y_g  =  (boxes2 [... , 1 ] +  boxes2 [... , 3 ]) /  2 
431432    # The distance between boxes' centers squared. 
432-     centers_distance_squared  =  (_upcast (x_p [:, None ] -  x_g [None , :]) **  2 ) +  (_upcast (y_p [:, None ] -  y_g [None , :]) **  2 )
433+     centers_distance_squared  =  (_upcast (x_p [..., None ] -  x_g [..., None , :]) **  2 ) +  (
434+         _upcast (y_p [..., None ] -  y_g [..., None , :]) **  2 
435+     )
433436    # The distance IoU is the IoU penalized by a normalized 
434437    # distance between boxes' centers squared. 
435438    return  iou  -  (centers_distance_squared  /  diagonal_distance_squared ), iou 
0 commit comments