Skip to content

Upcast rotated box transforms #9175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,12 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
The output maintains the same dtype as the input.
"""
dtype = parallelogram.dtype
acceptable_dtypes = [torch.float32, torch.float64]
need_cast = dtype not in acceptable_dtypes
if need_cast:
        # Upcast to avoid overflow in squaring operations
parallelogram = parallelogram.to(torch.float32)
out_boxes = parallelogram.clone()

# Calculate parallelogram diagonal vectors
Expand Down Expand Up @@ -489,6 +495,10 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1])
out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4])
out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5])

if need_cast:
out_boxes = out_boxes.to(dtype)

return out_boxes


Expand Down
34 changes: 4 additions & 30 deletions torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
if not inplace:
cxcywhr = cxcywhr.clone()

dtype = cxcywhr.dtype
need_cast = not cxcywhr.is_floating_point()
if need_cast:
cxcywhr = cxcywhr.float()

half_wh = cxcywhr[..., 2:-1].div(-2, rounding_mode=None if cxcywhr.is_floating_point() else "floor").abs_()
r_rad = cxcywhr[..., 4].mul(torch.pi).div(180.0)
cos, sin = r_rad.cos(), r_rad.sin()
Expand All @@ -207,22 +202,13 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
# (cy + width / 2 * sin - height / 2 * cos) = y1
cxcywhr[..., 1].add_(half_wh[..., 0].mul(sin)).sub_(half_wh[..., 1].mul(cos))

if need_cast:
cxcywhr.round_()
cxcywhr = cxcywhr.to(dtype)

return cxcywhr


def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
if not inplace:
xywhr = xywhr.clone()

dtype = xywhr.dtype
need_cast = not xywhr.is_floating_point()
if need_cast:
xywhr = xywhr.float()

half_wh = xywhr[..., 2:-1].div(-2, rounding_mode=None if xywhr.is_floating_point() else "floor").abs_()
r_rad = xywhr[..., 4].mul(torch.pi).div(180.0)
cos, sin = r_rad.cos(), r_rad.sin()
Expand All @@ -231,10 +217,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
# (y1 - width / 2 * sin + height / 2 * cos) = cy
xywhr[..., 1].sub_(half_wh[..., 0].mul(sin)).add_(half_wh[..., 1].mul(cos))

if need_cast:
xywhr.round_()
xywhr = xywhr.to(dtype)

return xywhr


Expand All @@ -243,11 +225,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
if not inplace:
xywhr = xywhr.clone()

dtype = xywhr.dtype
need_cast = not xywhr.is_floating_point()
if need_cast:
xywhr = xywhr.float()

wh = xywhr[..., 2:-1]
r_rad = xywhr[..., 4].mul(torch.pi).div(180.0)
cos, sin = r_rad.cos(), r_rad.sin()
Expand All @@ -265,10 +242,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
# y1 + h * cos = y4
xywhr[..., 7].add_(wh[..., 1].mul(cos))

if need_cast:
xywhr.round_()
xywhr = xywhr.to(dtype)

return xywhr


Expand All @@ -278,9 +251,11 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
xyxyxyxy = xyxyxyxy.clone()

dtype = xyxyxyxy.dtype
need_cast = not xyxyxyxy.is_floating_point()
acceptable_dtypes = [torch.float32, torch.float64] # Ensure consistency between CPU and GPU.
need_cast = dtype not in acceptable_dtypes
if need_cast:
xyxyxyxy = xyxyxyxy.float()
        # Upcast to avoid overflow in squaring operations
xyxyxyxy = xyxyxyxy.to(torch.float32)

r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0]))
# x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4
Expand All @@ -293,7 +268,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0)

if need_cast:
xyxyxyxy.round_()
xyxyxyxy = xyxyxyxy.to(dtype)

return xyxyxyxy[..., :5]
Expand Down
Loading