Skip to content

Commit 87a238c

Browse files
Modify clamping for rotated boxes
1 parent 9827ab6 commit 87a238c

File tree

1 file changed

+22
-2
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+22
-2
lines changed

torchvision/transforms/v2/functional/_meta.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,16 @@ def _clamp_bounding_boxes(
360360
return out_boxes.to(in_dtype)
361361

362362

363+
def _clamp_rotated_bounding_boxes(
364+
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int]
365+
) -> torch.Tensor:
366+
# TODO: For now we are not clamping rotated bounding boxes.
367+
in_dtype = bounding_boxes.dtype
368+
out_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
369+
370+
return out_boxes.to(in_dtype)
371+
372+
363373
def clamp_bounding_boxes(
364374
inpt: torch.Tensor,
365375
format: Optional[BoundingBoxFormat] = None,
@@ -373,11 +383,21 @@ def clamp_bounding_boxes(
373383

374384
if format is None or canvas_size is None:
375385
raise ValueError("For pure tensor inputs, `format` and `canvas_size` have to be passed.")
376-
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
386+
if tv_tensors.is_rotated_bounding_format(format):
387+
return _clamp_rotated_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
388+
else:
389+
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
377390
elif isinstance(inpt, tv_tensors.BoundingBoxes):
378391
if format is not None or canvas_size is not None:
379392
raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
380-
output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
393+
if tv_tensors.is_rotated_bounding_format(inpt.format):
394+
output = _clamp_rotated_bounding_boxes(
395+
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
396+
)
397+
else:
398+
output = _clamp_bounding_boxes(
399+
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
400+
)
381401
return tv_tensors.wrap(output, like=inpt)
382402
else:
383403
raise TypeError(

0 commit comments

Comments
 (0)