@@ -405,16 +405,27 @@ def _get_sanitize_bounding_boxes_mask(
405405 min_area : float = 1.0 ,
406406) -> torch .Tensor :
407407
408- bounding_boxes = _convert_bounding_box_format (
409- bounding_boxes , new_format = tv_tensors .BoundingBoxFormat .XYXY , old_format = format
410- )
408+ is_rotated = tv_tensors . is_rotated_bounding_format ( format )
409+ intermediate_format = tv_tensors . BoundingBoxFormat . XYXYXYXY if is_rotated else tv_tensors .BoundingBoxFormat .XYXY
410+ bounding_boxes = _convert_bounding_box_format ( bounding_boxes , new_format = intermediate_format , old_format = format )
411411
412412 image_h , image_w = canvas_size
413- ws , hs = bounding_boxes [:, 2 ] - bounding_boxes [:, 0 ], bounding_boxes [:, 3 ] - bounding_boxes [:, 1 ]
413+ if is_rotated :
414+ dx12 = bounding_boxes [..., 0 ] - bounding_boxes [..., 2 ]
415+ dy12 = bounding_boxes [..., 1 ] - bounding_boxes [..., 3 ]
416+ dx23 = bounding_boxes [..., 3 ] - bounding_boxes [..., 5 ]
417+ dy23 = bounding_boxes [..., 4 ] - bounding_boxes [..., 6 ]
418+ ws = torch .sqrt (dx12 ** 2 + dy12 ** 2 )
419+ hs = torch .sqrt (dx23 ** 2 + dy23 ** 2 )
420+ else :
421+ ws , hs = bounding_boxes [:, 2 ] - bounding_boxes [:, 0 ], bounding_boxes [:, 3 ] - bounding_boxes [:, 1 ]
414422 valid = (ws >= min_size ) & (hs >= min_size ) & (bounding_boxes >= 0 ).all (dim = - 1 ) & (ws * hs >= min_area )
415423 # TODO: Do we really need to check for out of bounds here? All
416424 # transforms should be clamping anyway, so this should never happen?
417425 image_h , image_w = canvas_size
418426 valid &= (bounding_boxes [:, 0 ] <= image_w ) & (bounding_boxes [:, 2 ] <= image_w )
419427 valid &= (bounding_boxes [:, 1 ] <= image_h ) & (bounding_boxes [:, 3 ] <= image_h )
428+ if is_rotated :
429+ valid &= (bounding_boxes [..., 4 ] <= image_w ) & (bounding_boxes [..., 5 ] <= image_h )
430+ valid &= (bounding_boxes [..., 6 ] <= image_w ) & (bounding_boxes [..., 7 ] <= image_h )
420431 return valid
0 commit comments