Skip to content

Commit 3a8723e

Browse files
committed
Merge branch 'main' of github.com:pytorch/vision into autobbox
2 parents: ca115ad + 80cb38e · commit 3a8723e

File tree

2 files changed

+44
-19
lines changed

2 files changed

+44
-19
lines changed

test/test_transforms_v2.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2237,7 +2237,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
22372237
@pytest.mark.parametrize("expand", [False, True])
22382238
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
22392239
def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
2240-
bounding_boxes = make_bounding_boxes(format=format, clamping_mode=None)
2240+
bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none")
22412241

22422242
actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
22432243
expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)
@@ -2249,7 +2249,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent
22492249
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
22502250
@pytest.mark.parametrize("seed", list(range(5)))
22512251
def test_transform_bounding_boxes_correctness(self, format, expand, center, seed):
2252-
bounding_boxes = make_bounding_boxes(format=format, clamping_mode=None)
2252+
bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none")
22532253

22542254
transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center)
22552255

@@ -4428,7 +4428,7 @@ def test_functional_bounding_boxes_correctness(self, format):
44284428
# _reference_resized_crop_bounding_boxes we are fusing the crop and the
44294429
# resize operation, where none of the croppings happen - particularly,
44304430
# the intermediate one.
4431-
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, clamping_mode=None)
4431+
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, clamping_mode="none")
44324432

44334433
actual = F.resized_crop(bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE)
44344434
expected = self._reference_resized_crop_bounding_boxes(

torchvision/transforms/v2/functional/_meta.py

Lines changed: 41 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -466,29 +466,54 @@ def _clamp_y_intercept(
466466
then applies various constraints to ensure the clamping conditions are respected.
467467
"""
468468

469+
# Calculate slopes and y-intercepts for bounding boxes
469470
a, b = _get_slope_and_intercept(bounding_boxes)
470471
a1, a2, a3, a4 = a.unbind(-1)
471472
b1, b2, b3, b4 = b.unbind(-1)
472473

473-
# Clamp y-intercepts (soft clamping)
474+
# Get y-intercepts from original bounding boxes
475+
_, bm = _get_slope_and_intercept(original_bounding_boxes)
476+
b1m, b2m, b3m, b4m = bm.unbind(-1)
477+
478+
# Soft clamping: Clamp y-intercepts within canvas boundaries
474479
b1 = b2.clamp(b1, b3).clamp(0, canvas_size[0])
475480
b4 = b3.clamp(b2, b4).clamp(0, canvas_size[0])
476481

477-
if clamping_mode is not None and clamping_mode == "hard":
478-
# Get y-intercepts from original bounding boxes
479-
_, b = _get_slope_and_intercept(original_bounding_boxes)
480-
_, b2, b3, _ = b.unbind(-1)
481-
482-
# Set b1 and b4 to the average of their clamped values
483-
b1 = b4 = (b1.clamp(0, canvas_size[0]) + b4.clamp(0, canvas_size[0])) / 2
482+
if clamping_mode == "hard":
483+
# Hard clamping: Average b1 and b4, and adjust b2 and b3 for maximum area
484+
b1 = b4 = (b1 + b4) / 2
485+
486+
# Calculate candidate values for b2 based on geometric constraints
487+
b2_candidates = torch.stack(
488+
[
489+
b1 * a2 / a1, # Constraint at y=0
490+
b3 * a2 / a3, # Constraint at y=0
491+
(a1 - a2) * canvas_size[1] + b1, # Constraint at x=canvas_width
492+
(a3 - a2) * canvas_size[1] + b3, # Constraint at x=canvas_width
493+
],
494+
dim=1,
495+
)
496+
# Take maximum value that doesn't exceed original b2
497+
b2 = torch.max(b2_candidates, dim=1)[0].clamp(max=b2)
498+
499+
# Calculate candidate values for b3 based on geometric constraints
500+
b3_candidates = torch.stack(
501+
[
502+
canvas_size[0] * (1 - a3 / a4) + b4 * a3 / a4, # Constraint at y=canvas_height
503+
canvas_size[0] * (1 - a3 / a2) + b2 * a3 / a2, # Constraint at y=canvas_height
504+
(a2 - a3) * canvas_size[1] + b2, # Constraint at x=canvas_width
505+
(a4 - a3) * canvas_size[1] + b4, # Constraint at x=canvas_width
506+
],
507+
dim=1,
508+
)
509+
# Take minimum value that doesn't go below original b3
510+
b3 = torch.min(b3_candidates, dim=1)[0].clamp(min=b3)
484511

485-
# Ensure b2 and b3 defined the box of maximum area after clamping b1 and b4
486-
b2.clamp_(b1 * a2 / a1, b4).clamp_((a1 - a2) * canvas_size[1] + b1)
487-
b2.clamp_(b3 * a2 / a3, b4).clamp_((a3 - a2) * canvas_size[1] + b3)
488-
b3.clamp_(max=canvas_size[0] * (1 - a3 / a4) + b4 * a3 / a4)
489-
b3.clamp_(max=canvas_size[0] * (1 - a3 / a2) + b2 * a3 / a2)
490-
b3.clamp_(b1, (a2 - a3) * canvas_size[1] + b2)
491-
b3.clamp_(b1, (a4 - a3) * canvas_size[1] + b4)
512+
# Final clamping to ensure y-intercepts are within original box bounds
513+
b1.clamp_(b1m, b3m)
514+
b3.clamp_(b1m, b3m)
515+
b2.clamp_(b2m, b4m)
516+
b4.clamp_(b2m, b4m)
492517

493518
return torch.stack([b1, b2, b3, b4], dim=-1)
494519

@@ -549,7 +574,7 @@ def _clamp_along_y_axis(
549574
[case_a, case_b, case_c],
550575
):
551576
bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
552-
if clamping_mode is not None and clamping_mode == "hard":
577+
if clamping_mode == "hard":
553578
bounding_boxes[..., 0].clamp_(0) # Clamp x1 to 0
554579

555580
if need_cast:

0 commit comments

Comments
 (0)