Skip to content

Commit dffd5ae

Browse files
Update sanitize_bounding_boxes for rotated boxes
1 parent 72b33b0 commit dffd5ae

File tree

1 file changed

+15
-4
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+15
-4
lines changed

torchvision/transforms/v2/functional/_misc.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,16 +405,27 @@ def _get_sanitize_bounding_boxes_mask(
405405
min_area: float = 1.0,
406406
) -> torch.Tensor:
407407

408-
bounding_boxes = _convert_bounding_box_format(
409-
bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format
410-
)
408+
is_rotated = tv_tensors.is_rotated_bounding_format(format)
409+
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
410+
bounding_boxes = _convert_bounding_box_format(bounding_boxes, new_format=intermediate_format, old_format=format)
411411

412412
image_h, image_w = canvas_size
413-
ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
413+
if is_rotated:
414+
dx12 = bounding_boxes[..., 0] - bounding_boxes[..., 2]
415+
dy12 = bounding_boxes[..., 1] - bounding_boxes[..., 3]
416+
dx23 = bounding_boxes[..., 3] - bounding_boxes[..., 5]
417+
dy23 = bounding_boxes[..., 4] - bounding_boxes[..., 6]
418+
ws = torch.sqrt(dx12**2 + dy12**2)
419+
hs = torch.sqrt(dx23**2 + dy23**2)
420+
else:
421+
ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
414422
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) & (ws * hs >= min_area)
415423
# TODO: Do we really need to check for out of bounds here? All
416424
# transforms should be clamping anyway, so this should never happen?
417425
image_h, image_w = canvas_size
418426
valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w)
419427
valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h)
428+
if is_rotated:
429+
valid &= (bounding_boxes[..., 4] <= image_w) & (bounding_boxes[..., 5] <= image_h)
430+
valid &= (bounding_boxes[..., 6] <= image_w) & (bounding_boxes[..., 7] <= image_h)
420431
return valid

0 commit comments

Comments
 (0)