Skip to content

Commit 1e4d8ae

Browse files
remove eps in _clamp_along_y_axis
1 parent e417da1 commit 1e4d8ae

File tree

1 file changed

+5
-8
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+5
-8
lines changed

torchvision/transforms/v2/functional/_meta.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,6 @@ def _clamp_along_y_axis(
543543
dtype = bounding_boxes.dtype
544544
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
545545
need_cast = dtype not in acceptable_dtypes
546-
eps = 1e-06 # Ensure consistency between CPU and GPU.
547546
original_shape = bounding_boxes.shape
548547
bounding_boxes = bounding_boxes.reshape(-1, 8)
549548
original_bounding_boxes = original_bounding_boxes.reshape(-1, 8)
@@ -559,23 +558,21 @@ def _clamp_along_y_axis(
559558
case_b[..., 6].clamp_(0) # Clamp x4 to 0
560559
case_c = torch.zeros_like(case_b)
561560

562-
cond_a = (x1 < eps) & ~case_a.isnan().any(-1) # First point is outside left boundary
563-
cond_b = y1.isclose(y2, rtol=eps, atol=eps) | y3.isclose(y4, rtol=eps, atol=eps) # First line is nearly vertical
561+
cond_a = (x1 < 0) & ~case_a.isnan().any(-1) # First point is outside left boundary
562+
cond_b = y1.isclose(y2) | y3.isclose(y4) # First line is nearly vertical
564563
cond_c = (x1 <= 0) & (x2 <= 0) & (x3 <= 0) & (x4 <= 0) # All points outside left boundary
565564
cond_c = (
566565
cond_c
567-
| y1.isclose(y4, rtol=eps, atol=eps)
568-
| y2.isclose(y3, rtol=eps, atol=eps)
569-
| (cond_b & x1.isclose(x2, rtol=eps, atol=eps))
566+
| y1.isclose(y4)
567+
| y2.isclose(y3)
568+
| (cond_b & x1.isclose(x2))
570569
) # First line is nearly horizontal
571570

572571
for (cond, case) in zip(
573572
[cond_a, cond_b, cond_c],
574573
[case_a, case_b, case_c],
575574
):
576575
bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
577-
if clamping_mode == "hard":
578-
bounding_boxes[..., 0].clamp_(0) # Clamp x1 to 0
579576

580577
if need_cast:
581578
bounding_boxes = bounding_boxes.to(dtype)

0 commit comments

Comments
 (0)