@@ -318,19 +318,25 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple
318318 raise ValueError (f"Unsupported Box IoU Calculation for given fmt { format } ." )
319319
320320 if fmt == "xyxy" :
321- lt = torch .max (boxes1 [: , None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
322- rb = torch .min (boxes1 [: , None , 2 :], boxes2 [:, 2 :]) # [N,M,2]
321+ lt = torch .max (boxes1 [... , None , :2 ], boxes2 [..., None , :, :2 ]) # [..., N,M,2]
322+ rb = torch .min (boxes1 [... , None , 2 :], boxes2 [..., None , :, 2 :]) # [..., N,M,2]
323323 elif fmt == "xywh" :
324- lt = torch .max (boxes1 [:, None , :2 ], boxes2 [:, :2 ]) # [N,M,2]
325- rb = torch .min (boxes1 [:, None , :2 ] + boxes1 [:, None , 2 :], boxes2 [:, :2 ] + boxes2 [:, 2 :]) # [N,M,2]
324+ lt = torch .max (boxes1 [..., None , :2 ], boxes2 [..., None , :, :2 ]) # [...,N,M,2]
325+ rb = torch .min (
326+ boxes1 [..., None , :2 ] + boxes1 [..., None , 2 :], boxes2 [..., None , :2 ] + boxes2 [..., None , 2 :]
327+ ) # [N,M,2]
326328 else : # fmt == "cxcywh":
327- lt = torch .max (boxes1 [:, None , :2 ] - boxes1 [:, None , 2 :] / 2 , boxes2 [:, :2 ] - boxes2 [:, 2 :] / 2 ) # [N,M,2]
328- rb = torch .min (boxes1 [:, None , :2 ] + boxes1 [:, None , 2 :] / 2 , boxes2 [:, :2 ] + boxes2 [:, 2 :] / 2 ) # [N,M,2]
329+ lt = torch .max (
330+ boxes1 [..., None , :2 ] - boxes1 [..., None , 2 :] / 2 , boxes2 [..., None , :2 ] - boxes2 [..., None , 2 :] / 2
331+ ) # [N,M,2]
332+ rb = torch .min (
333+ boxes1 [..., None , :2 ] + boxes1 [..., None , 2 :] / 2 , boxes2 [..., None , :2 ] + boxes2 [..., None , 2 :] / 2
334+ ) # [N,M,2]
329335
330- wh = _upcast (rb - lt ).clamp (min = 0 ) # [N,M,2]
331- inter = wh [:, :, 0 ] * wh [:, :, 1 ] # [N,M]
336+ wh = _upcast (rb - lt ).clamp (min = 0 ) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
337+ inter = wh [..., 0 ] * wh [..., 1 ] # [N,M] inter = wh[ :, :, 0] * wh[:, :, 1] # [N,M]
332338
333- union = area1 [: , None ] + area2 - inter
339+ union = area1 [... , None ] + area2 [..., None , :] - inter
334340
335341 return inter , union
336342
0 commit comments