Skip to content

Commit 9b91d94

Browse files
Adjust hard clamping
Test Plan: `pytest test/test_transforms_v2.py -k box -v`
1 parent 90a578b commit 9b91d94

File tree

1 file changed

+40
-15
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+40
-15
lines changed

torchvision/transforms/v2/functional/_meta.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -466,29 +466,54 @@ def _clamp_y_intercept(
466466
then applies various constraints to ensure the clamping conditions are respected.
467467
"""
468468

469+
# Calculate slopes and y-intercepts for bounding boxes
469470
a, b = _get_slope_and_intercept(bounding_boxes)
470471
a1, a2, a3, a4 = a.unbind(-1)
471472
b1, b2, b3, b4 = b.unbind(-1)
472473

473-
# Clamp y-intercepts (soft clamping)
474+
# Get y-intercepts from original bounding boxes
475+
_, bm = _get_slope_and_intercept(original_bounding_boxes)
476+
b1m, b2m, b3m, b4m = bm.unbind(-1)
477+
478+
# Soft clamping: Clamp y-intercepts within canvas boundaries
474479
b1 = b2.clamp(b1, b3).clamp(0, canvas_size[0])
475480
b4 = b3.clamp(b2, b4).clamp(0, canvas_size[0])
476481

477482
if clamping_mode == "hard":
478-
# Get y-intercepts from original bounding boxes
479-
_, b = _get_slope_and_intercept(original_bounding_boxes)
480-
_, b2, b3, _ = b.unbind(-1)
481-
482-
# Set b1 and b4 to the average of their clamped values
483-
b1 = b4 = (b1.clamp(0, canvas_size[0]) + b4.clamp(0, canvas_size[0])) / 2
484-
485-
# Ensure b2 and b3 defined the box of maximum area after clamping b1 and b4
486-
b2.clamp_(b1 * a2 / a1, b4).clamp_((a1 - a2) * canvas_size[1] + b1)
487-
b2.clamp_(b3 * a2 / a3, b4).clamp_((a3 - a2) * canvas_size[1] + b3)
488-
b3.clamp_(max=canvas_size[0] * (1 - a3 / a4) + b4 * a3 / a4)
489-
b3.clamp_(max=canvas_size[0] * (1 - a3 / a2) + b2 * a3 / a2)
490-
b3.clamp_(b1, (a2 - a3) * canvas_size[1] + b2)
491-
b3.clamp_(b1, (a4 - a3) * canvas_size[1] + b4)
483+
# Hard clamping: Average b1 and b4, and adjust b2 and b3 for maximum area
484+
b1 = b4 = (b1 + b4) / 2
485+
486+
# Calculate candidate values for b2 based on geometric constraints
487+
b2_candidates = torch.stack(
488+
[
489+
b1 * a2 / a1, # Constraint at y=0
490+
b3 * a2 / a3, # Constraint at y=0
491+
(a1 - a2) * canvas_size[1] + b1, # Constraint at x=canvas_width
492+
(a3 - a2) * canvas_size[1] + b3, # Constraint at x=canvas_width
493+
],
494+
dim=1,
495+
)
496+
# Take the largest candidate, capped at the incoming b2 (from bounding_boxes, not original_bounding_boxes)
497+
b2 = torch.max(b2_candidates, dim=1)[0].clamp(max=b2)
498+
499+
# Calculate candidate values for b3 based on geometric constraints
500+
b3_candidates = torch.stack(
501+
[
502+
canvas_size[0] * (1 - a3 / a4) + b4 * a3 / a4, # Constraint at y=canvas_height
503+
canvas_size[0] * (1 - a3 / a2) + b2 * a3 / a2, # Constraint at y=canvas_height
504+
(a2 - a3) * canvas_size[1] + b2, # Constraint at x=canvas_width
505+
(a4 - a3) * canvas_size[1] + b4, # Constraint at x=canvas_width
506+
],
507+
dim=1,
508+
)
509+
# Take the smallest candidate, floored at the incoming b3 (from bounding_boxes, not original_bounding_boxes)
510+
b3 = torch.min(b3_candidates, dim=1)[0].clamp(min=b3)
511+
512+
# Final clamping to ensure y-intercepts are within original box bounds
513+
b1.clamp_(b1m, b3m)
514+
b3.clamp_(b1m, b3m)
515+
b2.clamp_(b2m, b4m)
516+
b4.clamp_(b2m, b4m)
492517

493518
return torch.stack([b1, b2, b3, b4], dim=-1)
494519

0 commit comments

Comments
 (0)