Skip to content

Commit 8c9aaed

Browse files
Upcast rotated box transforms (#9175)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent c85f008 commit 8c9aaed

File tree

2 files changed

+14
-30
lines changed

2 files changed

+14
-30
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,12 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
451451
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
452452
The output maintains the same dtype as the input.
453453
"""
454+
dtype = parallelogram.dtype
455+
acceptable_dtypes = [torch.float32, torch.float64]
456+
need_cast = dtype not in acceptable_dtypes
457+
if need_cast:
458+
# Up-case to avoid overflow for square operations
459+
parallelogram = parallelogram.to(torch.float32)
454460
out_boxes = parallelogram.clone()
455461

456462
# Calculate parallelogram diagonal vectors
@@ -489,6 +495,10 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
489495
out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1])
490496
out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4])
491497
out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5])
498+
499+
if need_cast:
500+
out_boxes = out_boxes.to(dtype)
501+
492502
return out_boxes
493503

494504

torchvision/transforms/v2/functional/_meta.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
194194
if not inplace:
195195
cxcywhr = cxcywhr.clone()
196196

197-
dtype = cxcywhr.dtype
198-
need_cast = not cxcywhr.is_floating_point()
199-
if need_cast:
200-
cxcywhr = cxcywhr.float()
201-
202197
half_wh = cxcywhr[..., 2:-1].div(-2, rounding_mode=None if cxcywhr.is_floating_point() else "floor").abs_()
203198
r_rad = cxcywhr[..., 4].mul(torch.pi).div(180.0)
204199
cos, sin = r_rad.cos(), r_rad.sin()
@@ -207,22 +202,13 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
207202
# (cy + width / 2 * sin - height / 2 * cos) = y1
208203
cxcywhr[..., 1].add_(half_wh[..., 0].mul(sin)).sub_(half_wh[..., 1].mul(cos))
209204

210-
if need_cast:
211-
cxcywhr.round_()
212-
cxcywhr = cxcywhr.to(dtype)
213-
214205
return cxcywhr
215206

216207

217208
def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
218209
if not inplace:
219210
xywhr = xywhr.clone()
220211

221-
dtype = xywhr.dtype
222-
need_cast = not xywhr.is_floating_point()
223-
if need_cast:
224-
xywhr = xywhr.float()
225-
226212
half_wh = xywhr[..., 2:-1].div(-2, rounding_mode=None if xywhr.is_floating_point() else "floor").abs_()
227213
r_rad = xywhr[..., 4].mul(torch.pi).div(180.0)
228214
cos, sin = r_rad.cos(), r_rad.sin()
@@ -231,10 +217,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
231217
# (y1 - width / 2 * sin + height / 2 * cos) = cy
232218
xywhr[..., 1].sub_(half_wh[..., 0].mul(sin)).add_(half_wh[..., 1].mul(cos))
233219

234-
if need_cast:
235-
xywhr.round_()
236-
xywhr = xywhr.to(dtype)
237-
238220
return xywhr
239221

240222

@@ -243,11 +225,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
243225
if not inplace:
244226
xywhr = xywhr.clone()
245227

246-
dtype = xywhr.dtype
247-
need_cast = not xywhr.is_floating_point()
248-
if need_cast:
249-
xywhr = xywhr.float()
250-
251228
wh = xywhr[..., 2:-1]
252229
r_rad = xywhr[..., 4].mul(torch.pi).div(180.0)
253230
cos, sin = r_rad.cos(), r_rad.sin()
@@ -265,10 +242,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
265242
# y1 + h * cos = y4
266243
xywhr[..., 7].add_(wh[..., 1].mul(cos))
267244

268-
if need_cast:
269-
xywhr.round_()
270-
xywhr = xywhr.to(dtype)
271-
272245
return xywhr
273246

274247

@@ -278,9 +251,11 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
278251
xyxyxyxy = xyxyxyxy.clone()
279252

280253
dtype = xyxyxyxy.dtype
281-
need_cast = not xyxyxyxy.is_floating_point()
254+
acceptable_dtypes = [torch.float32, torch.float64] # Ensure consistency between CPU and GPU.
255+
need_cast = dtype not in acceptable_dtypes
282256
if need_cast:
283-
xyxyxyxy = xyxyxyxy.float()
257+
# Up-case to avoid overflow for square operations
258+
xyxyxyxy = xyxyxyxy.to(torch.float32)
284259

285260
r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0]))
286261
# x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4
@@ -293,7 +268,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
293268
xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0)
294269

295270
if need_cast:
296-
xyxyxyxy.round_()
297271
xyxyxyxy = xyxyxyxy.to(dtype)
298272

299273
return xyxyxyxy[..., :5]

0 commit comments

Comments (0)