Skip to content

Commit ed95753

Browse files
Disable torch.float64 precision in tests and transforms
1 parent 1e4d8ae commit ed95753

File tree

3 files changed

+9
-33
lines changed

3 files changed

+9
-33
lines changed

test/test_transforms_v2.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,6 @@ def affine_rotated_bounding_boxes(bounding_boxes):
604604
output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
605605
)
606606

607-
# For rotated boxes, it is important to cast before clamping.
608607
return (
609608
F.clamp_bounding_boxes(
610609
output.to(dtype=dtype, device=device),
@@ -2021,6 +2020,10 @@ def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
20212020
pytest.xfail("Rotated bounding boxes should be floating point tensors")
20222021

20232022
bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
2023+
if tv_tensors.is_rotated_bounding_format(format):
2024+
# TODO there is a 1e-6 difference between GPU and CPU outputs
2025+
# due to clamping. To avoid failing this test, we do clamp before hand.
2026+
bounding_boxes = F.clamp_bounding_boxes(bounding_boxes)
20242027

20252028
check_kernel(
20262029
F.rotate_bounding_boxes,
@@ -5592,7 +5595,7 @@ def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode,
55925595
boxes = tv_tensors.BoundingBoxes(
55935596
[0, 0, 100, 100, 0], format="XYWHR", canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
55945597
)
5595-
expected_clamped_output = torch.tensor([[0, 0, 10, 10, 0]])
5598+
expected_clamped_output = torch.tensor([[0.0, 0.0, 10.0, 10.0, 0.0]])
55965599
else:
55975600
boxes = tv_tensors.BoundingBoxes(
55985601
[0, 100, 0, 100], format="XYXY", canvas_size=(10, 10), clamping_mode=constructor_clamping_mode

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,10 @@ def horizontal_flip_bounding_boxes(
104104
bounding_boxes[:, 0::2].sub_(canvas_size[1]).neg_()
105105
bounding_boxes = bounding_boxes[:, [2, 3, 0, 1, 6, 7, 4, 5]]
106106
elif format == tv_tensors.BoundingBoxFormat.XYWHR:
107-
108-
dtype = bounding_boxes.dtype
109-
if not torch.is_floating_point(bounding_boxes):
110-
# Casting to float to support cos and sin computations.
111-
bounding_boxes = bounding_boxes.to(torch.float32)
112107
angle_rad = bounding_boxes[:, 4].mul(torch.pi).div(180)
113108
bounding_boxes[:, 0].add_(bounding_boxes[:, 2].mul(angle_rad.cos())).sub_(canvas_size[1]).neg_()
114109
bounding_boxes[:, 1].sub_(bounding_boxes[:, 2].mul(angle_rad.sin()))
115110
bounding_boxes[:, 4].neg_()
116-
bounding_boxes = bounding_boxes.to(dtype)
117111
else: # format == tv_tensors.BoundingBoxFormat.CXCYWHR:
118112
bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()
119113
bounding_boxes[:, 4].neg_()
@@ -192,15 +186,10 @@ def vertical_flip_bounding_boxes(
192186
bounding_boxes[:, 1::2].sub_(canvas_size[0]).neg_()
193187
bounding_boxes = bounding_boxes[:, [2, 3, 0, 1, 6, 7, 4, 5]]
194188
elif format == tv_tensors.BoundingBoxFormat.XYWHR:
195-
dtype = bounding_boxes.dtype
196-
if not torch.is_floating_point(bounding_boxes):
197-
# Casting to float to support cos and sin computations.
198-
bounding_boxes = bounding_boxes.to(torch.float64)
199189
angle_rad = bounding_boxes[:, 4].mul(torch.pi).div(180)
200190
bounding_boxes[:, 1].sub_(bounding_boxes[:, 2].mul(angle_rad.sin())).sub_(canvas_size[0]).neg_()
201191
bounding_boxes[:, 0].add_(bounding_boxes[:, 2].mul(angle_rad.cos()))
202192
bounding_boxes[:, 4].neg_().add_(180)
203-
bounding_boxes = bounding_boxes.to(dtype)
204193
else: # format == tv_tensors.BoundingBoxFormat.CXCYWHR:
205194
bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()
206195
bounding_boxes[:, 4].neg_().add_(180)
@@ -1102,9 +1091,8 @@ def _affine_bounding_boxes_with_expand(
11021091

11031092
original_shape = bounding_boxes.shape
11041093
dtype = bounding_boxes.dtype
1105-
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
1106-
need_cast = dtype not in acceptable_dtypes
1107-
bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
1094+
need_cast = not bounding_boxes.is_floating_point()
1095+
bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone()
11081096
device = bounding_boxes.device
11091097
is_rotated = tv_tensors.is_rotated_bounding_format(format)
11101098
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY

torchvision/transforms/v2/functional/_meta.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -540,9 +540,6 @@ def _clamp_along_y_axis(
540540
Returns:
541541
torch.Tensor: The adjusted bounding boxes.
542542
"""
543-
dtype = bounding_boxes.dtype
544-
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
545-
need_cast = dtype not in acceptable_dtypes
546543
original_shape = bounding_boxes.shape
547544
bounding_boxes = bounding_boxes.reshape(-1, 8)
548545
original_bounding_boxes = original_bounding_boxes.reshape(-1, 8)
@@ -561,21 +558,14 @@ def _clamp_along_y_axis(
561558
cond_a = (x1 < 0) & ~case_a.isnan().any(-1) # First point is outside left boundary
562559
cond_b = y1.isclose(y2) | y3.isclose(y4) # First line is nearly vertical
563560
cond_c = (x1 <= 0) & (x2 <= 0) & (x3 <= 0) & (x4 <= 0) # All points outside left boundary
564-
cond_c = (
565-
cond_c
566-
| y1.isclose(y4)
567-
| y2.isclose(y3)
568-
| (cond_b & x1.isclose(x2))
569-
) # First line is nearly horizontal
561+
cond_c = cond_c | y1.isclose(y4) | y2.isclose(y3) | (cond_b & x1.isclose(x2)) # First line is nearly horizontal
570562

571563
for (cond, case) in zip(
572564
[cond_a, cond_b, cond_c],
573565
[case_a, case_b, case_c],
574566
):
575567
bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
576568

577-
if need_cast:
578-
bounding_boxes = bounding_boxes.to(dtype)
579569
return bounding_boxes.reshape(original_shape)
580570

581571

@@ -608,10 +598,7 @@ def _clamp_rotated_bounding_boxes(
608598
if clamping_mode is not None and clamping_mode == "none":
609599
return bounding_boxes.clone()
610600
original_shape = bounding_boxes.shape
611-
dtype = bounding_boxes.dtype
612-
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
613-
need_cast = dtype not in acceptable_dtypes
614-
bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
601+
bounding_boxes = bounding_boxes.clone()
615602
out_boxes = (
616603
convert_bounding_box_format(
617604
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, inplace=True
@@ -640,8 +627,6 @@ def _clamp_rotated_bounding_boxes(
640627
out_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format, inplace=True
641628
).reshape(original_shape)
642629

643-
if need_cast:
644-
out_boxes = out_boxes.to(dtype)
645630
return out_boxes
646631

647632

0 commit comments

Comments
 (0)