@@ -130,7 +130,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
130
130
the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.
131
131
132
132
Args:
133
- boxes (Tensor[N , 4]): boxes in ``(x1, y1, x2, y2)`` format
133
+ boxes (Tensor[... , 4]): boxes in ``(x1, y1, x2, y2)`` format
134
134
with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
135
135
min_size (float): minimum size
136
136
@@ -140,7 +140,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
140
140
"""
141
141
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
142
142
_log_api_usage_once (remove_small_boxes )
143
- ws , hs = boxes [: , 2 ] - boxes [: , 0 ], boxes [: , 3 ] - boxes [: , 1 ]
143
+ ws , hs = boxes [... , 2 ] - boxes [... , 0 ], boxes [... , 3 ] - boxes [... , 1 ]
144
144
keep = (ws >= min_size ) & (hs >= min_size )
145
145
keep = torch .where (keep )[0 ]
146
146
return keep
@@ -155,12 +155,12 @@ def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor:
155
155
the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.
156
156
157
157
Args:
158
- boxes (Tensor[N , 4]): boxes in ``(x1, y1, x2, y2)`` format
158
+ boxes (Tensor[... , 4]): boxes in ``(x1, y1, x2, y2)`` format
159
159
with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
160
160
size (Tuple[height, width]): size of the image
161
161
162
162
Returns:
163
- Tensor[N , 4]: clipped boxes
163
+ Tensor[... , 4]: clipped boxes
164
164
"""
165
165
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
166
166
_log_api_usage_once (clip_boxes_to_image )
@@ -276,7 +276,7 @@ def box_area(boxes: Tensor) -> Tensor:
276
276
(x1, y1, x2, y2) coordinates.
277
277
278
278
Args:
279
- boxes (Tensor[N , 4]): boxes for which the area will be computed. They
279
+ boxes (Tensor[... , 4]): boxes for which the area will be computed. They
280
280
are expected to be in (x1, y1, x2, y2) format with
281
281
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
282
282
@@ -286,7 +286,7 @@ def box_area(boxes: Tensor) -> Tensor:
286
286
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
287
287
_log_api_usage_once (box_area )
288
288
boxes = _upcast (boxes )
289
- return (boxes [: , 2 ] - boxes [: , 0 ]) * (boxes [: , 3 ] - boxes [: , 1 ])
289
+ return (boxes [... , 2 ] - boxes [... , 0 ]) * (boxes [... , 3 ] - boxes [... , 1 ])
290
290
291
291
292
292
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
@@ -295,13 +295,13 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
295
295
area1 = box_area (boxes1 )
296
296
area2 = box_area (boxes2 )
297
297
298
- lt = torch .max (boxes1 [: , None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
299
- rb = torch .min (boxes1 [: , None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
298
+ lt = torch .max (boxes1 [... , None , :2 ], boxes2 [..., None , :, :2 ]) # [..., N,M,2]
299
+ rb = torch .min (boxes1 [... , None , 2 :], boxes2 [..., None , :, 2 :]) # [..., N,M,2]
300
300
301
301
wh = _upcast (rb - lt ).clamp (min = 0 ) # [N,M,2]
302
- inter = wh [:, :, 0 ] * wh [:, : , 1 ] # [N,M]
302
+ inter = wh [..., 0 ] * wh [... , 1 ] # [N,M]
303
303
304
- union = area1 [: , None ] + area2 - inter
304
+ union = area1 [... , None ] + area2 [..., None , :] - inter
305
305
306
306
return inter , union
307
307
@@ -314,11 +314,12 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
314
314
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
315
315
316
316
Args:
317
- boxes1 (Tensor[N, 4]): first set of boxes
318
- boxes2 (Tensor[M, 4]): second set of boxes
317
+ boxes1 (Tensor[..., N, 4]): first set of boxes
318
+ boxes2 (Tensor[..., M, 4]): second set of boxes
319
319
320
320
Returns:
321
- Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
321
+ Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
322
+ in boxes1 and boxes2
322
323
"""
323
324
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
324
325
_log_api_usage_once (box_iou )
@@ -336,11 +337,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
336
337
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
337
338
338
339
Args:
339
- boxes1 (Tensor[N, 4]): first set of boxes
340
- boxes2 (Tensor[M, 4]): second set of boxes
340
+ boxes1 (Tensor[..., N, 4]): first set of boxes
341
+ boxes2 (Tensor[..., M, 4]): second set of boxes
341
342
342
343
Returns:
343
- Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
344
+ Tensor[..., N, M]: the NxM matrix containing the pairwise generalized IoU values
344
345
for every element in boxes1 and boxes2
345
346
"""
346
347
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
@@ -349,11 +350,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
349
350
inter , union = _box_inter_union (boxes1 , boxes2 )
350
351
iou = inter / union
351
352
352
- lti = torch .min (boxes1 [: , None , :2 ], boxes2 [:, :2 ])
353
- rbi = torch .max (boxes1 [: , None , 2 :], boxes2 [:, 2 :])
353
+ lti = torch .min (boxes1 [... , None , :2 ], boxes2 [..., None , :, :2 ])
354
+ rbi = torch .max (boxes1 [... , None , 2 :], boxes2 [..., None , :, 2 :])
354
355
355
356
whi = _upcast (rbi - lti ).clamp (min = 0 ) # [N,M,2]
356
- areai = whi [:, :, 0 ] * whi [:, : , 1 ]
357
+ areai = whi [..., 0 ] * whi [... , 1 ]
357
358
358
359
return iou - (areai - union ) / areai
359
360
@@ -364,11 +365,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
364
365
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
365
366
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
366
367
Args:
367
- boxes1 (Tensor[N, 4]): first set of boxes
368
- boxes2 (Tensor[M, 4]): second set of boxes
368
+ boxes1 (Tensor[..., N, 4]): first set of boxes
369
+ boxes2 (Tensor[..., M, 4]): second set of boxes
369
370
eps (float, optional): small number to prevent division by zero. Default: 1e-7
370
371
Returns:
371
- Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
372
+ Tensor[..., N, M]: the NxM matrix containing the pairwise complete IoU values
372
373
for every element in boxes1 and boxes2
373
374
"""
374
375
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
@@ -379,11 +380,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
379
380
380
381
diou , iou = _box_diou_iou (boxes1 , boxes2 , eps )
381
382
382
- w_pred = boxes1 [: , None , 2 ] - boxes1 [: , None , 0 ]
383
- h_pred = boxes1 [: , None , 3 ] - boxes1 [: , None , 1 ]
383
+ w_pred = boxes1 [... , None , 2 ] - boxes1 [... , None , 0 ]
384
+ h_pred = boxes1 [... , None , 3 ] - boxes1 [... , None , 1 ]
384
385
385
- w_gt = boxes2 [:, 2 ] - boxes2 [:, 0 ]
386
- h_gt = boxes2 [:, 3 ] - boxes2 [:, 1 ]
386
+ w_gt = boxes2 [..., None , :, 2 ] - boxes2 [..., None , :, 0 ]
387
+ h_gt = boxes2 [..., None , :, 3 ] - boxes2 [..., None , :, 1 ]
387
388
388
389
v = (4 / (torch .pi ** 2 )) * torch .pow (torch .atan (w_pred / h_pred ) - torch .atan (w_gt / h_gt ), 2 )
389
390
with torch .no_grad ():
@@ -399,12 +400,12 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
399
400
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
400
401
401
402
Args:
402
- boxes1 (Tensor[N, 4]): first set of boxes
403
- boxes2 (Tensor[M, 4]): second set of boxes
403
+ boxes1 (Tensor[..., N, 4]): first set of boxes
404
+ boxes2 (Tensor[..., M, 4]): second set of boxes
404
405
eps (float, optional): small number to prevent division by zero. Default: 1e-7
405
406
406
407
Returns:
407
- Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
408
+ Tensor[..., N, M]: the NxM matrix containing the pairwise distance IoU values
408
409
for every element in boxes1 and boxes2
409
410
"""
410
411
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
@@ -419,17 +420,19 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
419
420
def _box_diou_iou (boxes1 : Tensor , boxes2 : Tensor , eps : float = 1e-7 ) -> tuple [Tensor , Tensor ]:
420
421
421
422
iou = box_iou (boxes1 , boxes2 )
422
- lti = torch .min (boxes1 [: , None , :2 ], boxes2 [:, :2 ])
423
- rbi = torch .max (boxes1 [: , None , 2 :], boxes2 [:, 2 :])
423
+ lti = torch .min (boxes1 [... , None , :2 ], boxes2 [..., None , :, :2 ])
424
+ rbi = torch .max (boxes1 [... , None , 2 :], boxes2 [..., None , :, 2 :])
424
425
whi = _upcast (rbi - lti ).clamp (min = 0 ) # [N,M,2]
425
- diagonal_distance_squared = (whi [:, :, 0 ] ** 2 ) + (whi [:, : , 1 ] ** 2 ) + eps
426
+ diagonal_distance_squared = (whi [..., 0 ] ** 2 ) + (whi [... , 1 ] ** 2 ) + eps
426
427
# centers of boxes
427
- x_p = (boxes1 [: , 0 ] + boxes1 [: , 2 ]) / 2
428
- y_p = (boxes1 [: , 1 ] + boxes1 [: , 3 ]) / 2
429
- x_g = (boxes2 [: , 0 ] + boxes2 [: , 2 ]) / 2
430
- y_g = (boxes2 [: , 1 ] + boxes2 [: , 3 ]) / 2
428
+ x_p = (boxes1 [... , 0 ] + boxes1 [... , 2 ]) / 2
429
+ y_p = (boxes1 [... , 1 ] + boxes1 [... , 3 ]) / 2
430
+ x_g = (boxes2 [... , 0 ] + boxes2 [... , 2 ]) / 2
431
+ y_g = (boxes2 [... , 1 ] + boxes2 [... , 3 ]) / 2
431
432
# The distance between boxes' centers squared.
432
- centers_distance_squared = (_upcast (x_p [:, None ] - x_g [None , :]) ** 2 ) + (_upcast (y_p [:, None ] - y_g [None , :]) ** 2 )
433
+ centers_distance_squared = (_upcast (x_p [..., None ] - x_g [..., None , :]) ** 2 ) + (
434
+ _upcast (y_p [..., None ] - y_g [..., None , :]) ** 2
435
+ )
433
436
# The distance IoU is the IoU penalized by a normalized
434
437
# distance between boxes' centers squared.
435
438
return iou - (centers_distance_squared / diagonal_distance_squared ), iou
0 commit comments