Skip to content

Commit 9e3f7c0

Browse files
Adjust soft clamping
1 parent 42bae57 commit 9e3f7c0

File tree

1 file changed

+9
-8
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+9
-8
lines changed

torchvision/transforms/v2/functional/_meta.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def _clamp_y_intercept(
446446
bounding_boxes: torch.Tensor,
447447
original_bounding_boxes: torch.Tensor,
448448
canvas_size: tuple[int, int],
449-
clamping: str = "hard",
449+
clamping_mode: str = "hard",
450450
) -> torch.Tensor:
451451
"""
452452
Apply clamping to bounding box y-intercepts. This function handles two clamping strategies:
@@ -464,10 +464,10 @@ def _clamp_y_intercept(
464464
b1, b2, b3, b4 = b.unbind(-1)
465465

466466
# Clamp y-intercepts (soft clamping)
467-
b1 = b2.clamp(0).clamp(b1, b3)
468-
b4 = b3.clamp(max=canvas_size[0]).clamp(b2, b4)
467+
b1 = b2.clamp(b1, b3).clamp(0, canvas_size[0])
468+
b4 = b3.clamp(b2, b4).clamp(0, canvas_size[0])
469469

470-
if clamping == "hard":
470+
if clamping_mode == "hard":
471471
# Get y-intercepts from original bounding boxes
472472
_, b = _get_slope_and_intercept(original_bounding_boxes)
473473
_, b2, b3, _ = b.unbind(-1)
@@ -490,7 +490,7 @@ def _clamp_along_y_axis(
490490
bounding_boxes: torch.Tensor,
491491
original_bounding_boxes: torch.Tensor,
492492
canvas_size: tuple[int, int],
493-
clamping: str = "hard",
493+
clamping_mode: str = "hard",
494494
) -> torch.Tensor:
495495
"""
496496
Adjusts bounding boxes along the y-axis based on specific conditions.
@@ -503,7 +503,7 @@ def _clamp_along_y_axis(
503503
bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates.
504504
original_bounding_boxes (torch.Tensor): The original bounding boxes before any clamping is applied.
505505
canvas_size (tuple[int, int]): The size of the canvas as (height, width).
506-
clamping (str, optional): The clamping strategy to use. Defaults to "hard".
506+
clamping_mode (str, optional): The clamping strategy to use. Defaults to "hard".
507507
508508
Returns:
509509
torch.Tensor: The adjusted bounding boxes.
@@ -519,7 +519,7 @@ def _clamp_along_y_axis(
519519
# Calculate slopes (a) and y-intercepts (b) for all lines in the bounding boxes
520520
a, b = _get_slope_and_intercept(bounding_boxes)
521521
x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.unbind(-1)
522-
b = _clamp_y_intercept(bounding_boxes, original_bounding_boxes, canvas_size, clamping)
522+
b = _clamp_y_intercept(bounding_boxes, original_bounding_boxes, canvas_size, clamping_mode)
523523

524524
case_a = _get_intersection_point(a, b)
525525
case_b = bounding_boxes.clone()
@@ -537,7 +537,8 @@ def _clamp_along_y_axis(
537537
[case_a, case_b, case_c],
538538
):
539539
bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
540-
bounding_boxes[..., 0].clamp_(0) # Clamp x1 to 0
540+
if clamping_mode == "hard":
541+
bounding_boxes[..., 0].clamp_(0) # Clamp x1 to 0
541542

542543
if need_cast:
543544
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):

0 commit comments

Comments
 (0)