Skip to content

Commit e3ae3c9

Browse files
Apply suggestions from code review
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 0ff4716 commit e3ae3c9

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
452452
The output maintains the same dtype as the input.
453453
"""
454454
dtype = parallelogram.dtype
455-
acceptable_dtypes = [torch.float32]
455+
acceptable_dtypes = [torch.float32, torch.float64]
456456
need_cast = dtype not in acceptable_dtypes
457457
if need_cast:
458458
# Up-case to avoid overflow for square operations

torchvision/transforms/v2/functional/_meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
251251
xyxyxyxy = xyxyxyxy.clone()
252252

253253
dtype = xyxyxyxy.dtype
254-
acceptable_dtypes = [torch.float32] # Ensure consistency between CPU and GPU.
254+
acceptable_dtypes = [torch.float32, torch.float64] # Ensure consistency between CPU and GPU.
255255
need_cast = dtype not in acceptable_dtypes
256256
if need_cast:
257257
# Up-case to avoid overflow for square operations

0 commit comments

Comments
 (0)