diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 1c9ce3f6df0..6bfdf43fed6 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -451,6 +451,12 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso torch.Tensor: Tensor of same shape as input containing the rectangle coordinates. The output maintains the same dtype as the input. """ + dtype = parallelogram.dtype + acceptable_dtypes = [torch.float32, torch.float64] + need_cast = dtype not in acceptable_dtypes + if need_cast: + # Up-cast to avoid overflow for square operations + parallelogram = parallelogram.to(torch.float32) out_boxes = parallelogram.clone() # Calculate parallelogram diagonal vectors @@ -489,6 +495,10 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1]) out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4]) out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5]) + + if need_cast: + out_boxes = out_boxes.to(dtype) + return out_boxes diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6256a288203..4568b39ab59 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -194,11 +194,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor: if not inplace: cxcywhr = cxcywhr.clone() - dtype = cxcywhr.dtype - need_cast = not cxcywhr.is_floating_point() - if need_cast: - cxcywhr = cxcywhr.float() - half_wh = cxcywhr[..., 2:-1].div(-2, rounding_mode=None if cxcywhr.is_floating_point() else "floor").abs_() r_rad = cxcywhr[..., 4].mul(torch.pi).div(180.0) cos, sin = r_rad.cos(), r_rad.sin() @@ -207,10 +202,6 
@@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor: # (cy + width / 2 * sin - height / 2 * cos) = y1 cxcywhr[..., 1].add_(half_wh[..., 0].mul(sin)).sub_(half_wh[..., 1].mul(cos)) - if need_cast: - cxcywhr.round_() - cxcywhr = cxcywhr.to(dtype) - return cxcywhr @@ -218,11 +209,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: if not inplace: xywhr = xywhr.clone() - dtype = xywhr.dtype - need_cast = not xywhr.is_floating_point() - if need_cast: - xywhr = xywhr.float() - half_wh = xywhr[..., 2:-1].div(-2, rounding_mode=None if xywhr.is_floating_point() else "floor").abs_() r_rad = xywhr[..., 4].mul(torch.pi).div(180.0) cos, sin = r_rad.cos(), r_rad.sin() @@ -231,10 +217,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: # (y1 - width / 2 * sin + height / 2 * cos) = cy xywhr[..., 1].sub_(half_wh[..., 0].mul(sin)).add_(half_wh[..., 1].mul(cos)) - if need_cast: - xywhr.round_() - xywhr = xywhr.to(dtype) - return xywhr @@ -243,11 +225,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: if not inplace: xywhr = xywhr.clone() - dtype = xywhr.dtype - need_cast = not xywhr.is_floating_point() - if need_cast: - xywhr = xywhr.float() - wh = xywhr[..., 2:-1] r_rad = xywhr[..., 4].mul(torch.pi).div(180.0) cos, sin = r_rad.cos(), r_rad.sin() @@ -265,10 +242,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: # y1 + h * cos = y4 xywhr[..., 7].add_(wh[..., 1].mul(cos)) - if need_cast: - xywhr.round_() - xywhr = xywhr.to(dtype) - return xywhr @@ -278,9 +251,11 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: xyxyxyxy = xyxyxyxy.clone() dtype = xyxyxyxy.dtype - need_cast = not xyxyxyxy.is_floating_point() + acceptable_dtypes = [torch.float32, torch.float64] # Ensure consistency between CPU and GPU. 
+ need_cast = dtype not in acceptable_dtypes if need_cast: - xyxyxyxy = xyxyxyxy.float() + # Up-cast to avoid overflow for square operations + xyxyxyxy = xyxyxyxy.to(torch.float32) r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0])) # x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4 @@ -293,7 +268,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0) if need_cast: - xyxyxyxy.round_() xyxyxyxy = xyxyxyxy.to(dtype) return xyxyxyxy[..., :5]