Skip to content

Commit 646341d

Browse files
authored
Validate clamping_mode values (#9136)
1 parent fb3926e commit 646341d

File tree

5 files changed

+31
-14
lines changed

5 files changed

+31
-14
lines changed

test/test_transforms_v2.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5526,7 +5526,7 @@ def test_correctness_image(self, mean, std, dtype, fn):
55265526

55275527
class TestClampBoundingBoxes:
55285528
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5529-
@pytest.mark.parametrize("clamping_mode", ("hard", None)) # TODOBB add soft
5529+
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None))
55305530
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
55315531
@pytest.mark.parametrize("device", cpu_and_cuda())
55325532
def test_kernel(self, format, clamping_mode, dtype, device):
@@ -5542,7 +5542,7 @@ def test_kernel(self, format, clamping_mode, dtype, device):
55425542
)
55435543

55445544
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5545-
@pytest.mark.parametrize("clamping_mode", ("hard", None)) # TODOBB add soft
5545+
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None))
55465546
def test_functional(self, format, clamping_mode):
55475547
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format, clamping_mode=clamping_mode))
55485548

@@ -5566,12 +5566,17 @@ def test_errors(self):
55665566
):
55675567
F.clamp_bounding_boxes(input_tv_tensor, format=format_, canvas_size=canvas_size_)
55685568

5569+
with pytest.raises(ValueError, match="clamping_mode must be soft,"):
5570+
F.clamp_bounding_boxes(input_tv_tensor, clamping_mode="bad")
5571+
with pytest.raises(ValueError, match="clamping_mode must be soft,"):
5572+
transforms.ClampBoundingBoxes(clamping_mode="bad")(input_tv_tensor)
5573+
55695574
def test_transform(self):
55705575
check_transform(transforms.ClampBoundingBoxes(), make_bounding_boxes())
55715576

55725577
@pytest.mark.parametrize("rotated", (True, False))
5573-
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", None))
5574-
@pytest.mark.parametrize("clamping_mode", ("hard", None, "auto")) # TODOBB add soft here.
5578+
@pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None))
5579+
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None, "auto"))
55755580
@pytest.mark.parametrize("pass_pure_tensor", (True, False))
55765581
@pytest.mark.parametrize("fn", [F.clamp_bounding_boxes, transform_cls_to_functional(transforms.ClampBoundingBoxes)])
55775582
def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn):
@@ -5624,8 +5629,8 @@ def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode,
56245629

56255630
class TestSetClampingMode:
56265631
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5627-
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", None)) # TODOBB add soft
5628-
@pytest.mark.parametrize("desired_clamping_mode", ("hard", None)) # TODOBB add soft
5632+
@pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None))
5633+
@pytest.mark.parametrize("desired_clamping_mode", ("soft", "hard", None))
56295634
def test_setter(self, format, constructor_clamping_mode, desired_clamping_mode):
56305635

56315636
in_boxes = make_bounding_boxes(format=format, clamping_mode=constructor_clamping_mode)
@@ -5635,7 +5640,7 @@ def test_setter(self, format, constructor_clamping_mode, desired_clamping_mode):
56355640
assert out_boxes.clamping_mode == desired_clamping_mode
56365641

56375642
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
5638-
@pytest.mark.parametrize("constructor_clamping_mode", ("hard", None)) # TODOBB add soft
5643+
@pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None))
56395644
def test_pipeline_no_leak(self, format, constructor_clamping_mode):
56405645
class AssertClampingMode(transforms.Transform):
56415646
def __init__(self, expected_clamping_mode):
@@ -5669,6 +5674,10 @@ def transform(self, inpt, _):
56695674
# ClampBoundingBoxes doesn't set clamping_mode.
56705675
assert out_boxes.clamping_mode is None
56715676

5677+
def test_error(self):
5678+
with pytest.raises(ValueError, match="clamping_mode must be"):
5679+
transforms.SetClampingMode("bad")
5680+
56725681

56735682
class TestClampKeyPoints:
56745683
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])

test/test_tv_tensors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,14 @@ def test_return_type_input():
432432
tv_tensors.set_return_type("tensor")
433433

434434

435-
def test_box_clamping_mode_default():
435+
def test_box_clamping_mode_default_and_error():
436436
assert (
437437
tv_tensors.BoundingBoxes([0.0, 0.0, 10.0, 10.0], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft"
438438
)
439439
assert (
440440
tv_tensors.BoundingBoxes([0.0, 0.0, 10.0, 10.0, 0.0], format="XYWHR", canvas_size=(100, 100)).clamping_mode
441441
== "soft"
442442
)
443+
444+
with pytest.raises(ValueError, match="clamping_mode must be"):
445+
tv_tensors.BoundingBoxes([0, 0, 10, 10], format="XYXY", canvas_size=(100, 100), clamping_mode="bad")

torchvision/transforms/v2/_meta.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ class ClampBoundingBoxes(Transform):
3434
3535
"""
3636

37-
# TODOBB consider "auto" to be a Literal, make sur torchscript is still happy
38-
# TODOBB validate clamping_mode
3937
def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None:
4038
super().__init__()
4139
self.clamping_mode = clamping_mode
@@ -63,9 +61,11 @@ class SetClampingMode(Transform):
6361

6462
def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None:
6563
super().__init__()
66-
# TODOBB validate mode
6764
self.clamping_mode = clamping_mode
6865

66+
if self.clamping_mode not in (None, "soft", "hard"):
67+
raise ValueError(f"clamping_mode must be soft, hard or None, got {clamping_mode}")
68+
6969
_transformed_types = (tv_tensors.BoundingBoxes,)
7070

7171
def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:

torchvision/transforms/v2/functional/_meta.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,9 @@ def clamp_bounding_boxes(
640640
if not torch.jit.is_scripting():
641641
_log_api_usage_once(clamp_bounding_boxes)
642642

643+
if clamping_mode is not None and clamping_mode not in ("soft", "hard", "auto"):
644+
raise ValueError(f"clamping_mode must be soft, hard, auto or None, got {clamping_mode}")
645+
643646
if torch.jit.is_scripting() or is_pure_tensor(inpt):
644647

645648
if format is None or canvas_size is None or (clamping_mode is not None and clamping_mode == "auto"):

torchvision/tv_tensors/_bounding_boxes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def is_rotated_bounding_format(format: BoundingBoxFormat | str) -> bool:
5353
raise ValueError(f"format should be str or BoundingBoxFormat, got {type(format)}")
5454

5555

56-
# TODOBB consider making this a Literal instead. Tried briefly and got
57-
# torchscript errors, leaving to str for now.
56+
# This should ideally be a Literal, but torchscript fails.
5857
CLAMPING_MODE_TYPE = Optional[str]
5958

6059
# TODOBB All docs. Add any new API to rst files, add tutorial[s].
@@ -96,12 +95,15 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_
9695
tensor = tensor.unsqueeze(0)
9796
elif tensor.ndim != 2:
9897
raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
98+
if clamping_mode is not None and clamping_mode not in ("hard", "soft"):
99+
raise ValueError(f"clamping_mode must be None, hard or soft, got {clamping_mode}.")
100+
99101
if isinstance(format, str):
100102
format = BoundingBoxFormat[format.upper()]
103+
101104
bounding_boxes = tensor.as_subclass(cls)
102105
bounding_boxes.format = format
103106
bounding_boxes.canvas_size = canvas_size
104-
# TODOBB validate values
105107
bounding_boxes.clamping_mode = clamping_mode
106108
return bounding_boxes
107109

0 commit comments

Comments
 (0)