From 7c04dbae3a778db814ba20bb757d2ef035cb48d1 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 12:44:48 -0700 Subject: [PATCH 1/3] upcast rotated box transforms --- .../transforms/v2/functional/_geometry.py | 10 ++++++ torchvision/transforms/v2/functional/_meta.py | 34 +++---------------- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 1c9ce3f6df0..f1ace2a09cb 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -451,6 +451,12 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso torch.Tensor: Tensor of same shape as input containing the rectangle coordinates. The output maintains the same dtype as the input. """ + dtype = parallelogram.dtype + acceptable_dtypes = [torch.float32] + need_cast = dtype not in acceptable_dtypes + if need_cast: + # Up-case to avoid overflow for square operations + parallelogram = parallelogram.to(torch.float32) out_boxes = parallelogram.clone() # Calculate parallelogram diagonal vectors @@ -489,6 +495,10 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1]) out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4]) out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5]) + + if need_cast: + out_boxes = out_boxes.to(dtype) + return out_boxes diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6256a288203..15a39b99a54 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -194,11 +194,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor: if not inplace: cxcywhr = cxcywhr.clone() - dtype = cxcywhr.dtype - need_cast = not cxcywhr.is_floating_point() - if need_cast: - cxcywhr = cxcywhr.float() - half_wh = cxcywhr[..., 2:-1].div(-2, rounding_mode=None if cxcywhr.is_floating_point() else "floor").abs_() r_rad = cxcywhr[..., 4].mul(torch.pi).div(180.0) cos, sin = r_rad.cos(), r_rad.sin() @@ -207,10 +202,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor: # (cy + width / 2 * sin - height / 2 * cos) = y1 cxcywhr[..., 1].add_(half_wh[..., 0].mul(sin)).sub_(half_wh[..., 1].mul(cos)) - if need_cast: - cxcywhr.round_() - cxcywhr = cxcywhr.to(dtype) - return cxcywhr @@ -218,11 +209,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: if not inplace: xywhr = xywhr.clone() - dtype = xywhr.dtype - need_cast = not xywhr.is_floating_point() - if need_cast: - xywhr = xywhr.float() - half_wh = xywhr[..., 2:-1].div(-2, rounding_mode=None if xywhr.is_floating_point() else "floor").abs_() r_rad = xywhr[..., 4].mul(torch.pi).div(180.0) cos, sin = r_rad.cos(), r_rad.sin() @@ -231,10 +217,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: # (y1 - width / 2 * sin + height / 2 * cos) = cy xywhr[..., 1].sub_(half_wh[..., 0].mul(sin)).add_(half_wh[..., 1].mul(cos)) - if need_cast: - xywhr.round_() - xywhr = xywhr.to(dtype) - return xywhr @@ -243,11 +225,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: if not inplace: xywhr = xywhr.clone() - dtype = xywhr.dtype - need_cast = not xywhr.is_floating_point() - if need_cast: - xywhr = xywhr.float() - wh = xywhr[..., 2:-1] r_rad = xywhr[..., 4].mul(torch.pi).div(180.0) cos, sin = r_rad.cos(), r_rad.sin() @@ -265,10 +242,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: # y1 + h * cos = y4 xywhr[..., 7].add_(wh[..., 1].mul(cos)) - if need_cast: - xywhr.round_() - xywhr = xywhr.to(dtype) - return xywhr @@ -278,7 +251,11 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: xyxyxyxy = xyxyxyxy.clone() dtype = xyxyxyxy.dtype - need_cast = not xyxyxyxy.is_floating_point() + acceptable_dtypes = [torch.float32] # Ensure consistency between CPU and GPU. + need_cast = dtype not in acceptable_dtypes + if need_cast: + # Up-case to avoid overflow for square operations + xyxyxyxy = xyxyxyxy.to(torch.float32) if need_cast: xyxyxyxy = xyxyxyxy.float() @@ -293,7 +270,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0) if need_cast: - xyxyxyxy.round_() xyxyxyxy = xyxyxyxy.to(dtype) return xyxyxyxy[..., :5] From 0ff471684b9077bb3de823fbff3047618d65f4de Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Mon, 11 Aug 2025 13:02:53 -0700 Subject: [PATCH 2/3] Remove double lines for casting --- torchvision/transforms/v2/functional/_meta.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 15a39b99a54..0701e36dca7 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -256,8 +256,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: if need_cast: # Up-case to avoid overflow for square operations xyxyxyxy = xyxyxyxy.to(torch.float32) - if need_cast: - xyxyxyxy = xyxyxyxy.float() r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0])) # x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4 From e3ae3c9dba1d6dab2bc78381184d95264bc65ea2 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Thu, 14 Aug 2025 12:26:16 -0500 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Nicolas Hug --- torchvision/transforms/v2/functional/_geometry.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index f1ace2a09cb..6bfdf43fed6 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -452,7 +452,7 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso The output maintains the same dtype as the input. """ dtype = parallelogram.dtype - acceptable_dtypes = [torch.float32] + acceptable_dtypes = [torch.float32, torch.float64] need_cast = dtype not in acceptable_dtypes if need_cast: # Up-case to avoid overflow for square operations diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 0701e36dca7..4568b39ab59 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -251,7 +251,7 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: xyxyxyxy = xyxyxyxy.clone() dtype = xyxyxyxy.dtype - acceptable_dtypes = [torch.float32] # Ensure consistency between CPU and GPU. + acceptable_dtypes = [torch.float32, torch.float64] # Ensure consistency between CPU and GPU. need_cast = dtype not in acceptable_dtypes if need_cast: # Up-case to avoid overflow for square operations