diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 416b2e4facb..b05b04cca89 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2237,7 +2237,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen @pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) def test_functional_bounding_boxes_correctness(self, format, angle, expand, center): - bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none") + bounding_boxes = make_bounding_boxes(format=format, clamping_mode=None) actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center) expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center) @@ -2249,7 +2249,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @pytest.mark.parametrize("seed", list(range(5))) def test_transform_bounding_boxes_correctness(self, format, expand, center, seed): - bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none") + bounding_boxes = make_bounding_boxes(format=format, clamping_mode=None) transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center) @@ -4428,7 +4428,7 @@ def test_functional_bounding_boxes_correctness(self, format): # _reference_resized_crop_bounding_boxes we are fusing the crop and the # resize operation, where none of the croppings happen - particularly, # the intermediate one. - bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, clamping_mode="none") + bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, clamping_mode=None) actual = F.resized_crop(bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE) expected = self._reference_resized_crop_bounding_boxes( @@ -5507,7 +5507,7 @@ def test_correctness_image(self, mean, std, dtype, fn): class TestClampBoundingBoxes: @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) - @pytest.mark.parametrize("clamping_mode", ("hard", "none")) # TODOBB add soft + @pytest.mark.parametrize("clamping_mode", ("hard", None)) # TODOBB add soft @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel(self, format, clamping_mode, dtype, device): @@ -5521,7 +5521,7 @@ def test_kernel(self, format, clamping_mode, dtype, device): ) @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) - @pytest.mark.parametrize("clamping_mode", ("hard", "none")) # TODOBB add soft + @pytest.mark.parametrize("clamping_mode", ("hard", None)) # TODOBB add soft def test_functional(self, format, clamping_mode): check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format, clamping_mode=clamping_mode)) @@ -5531,7 +5531,7 @@ def test_errors(self): format, canvas_size = input_tv_tensor.format, input_tv_tensor.canvas_size for format_, canvas_size_, clamping_mode_ in itertools.product( - (format, None), (canvas_size, None), (input_tv_tensor.clamping_mode, None) + (format, None), (canvas_size, None), (input_tv_tensor.clamping_mode, "auto") ): with pytest.raises( ValueError, @@ -5549,8 +5549,8 @@ def test_transform(self): check_transform(transforms.ClampBoundingBoxes(), make_bounding_boxes()) @pytest.mark.parametrize("rotated", (True, False)) - @pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none")) - @pytest.mark.parametrize("clamping_mode", ("hard", "none", None)) # TODOBB add soft here. + @pytest.mark.parametrize("constructor_clamping_mode", ("hard", None)) + @pytest.mark.parametrize("clamping_mode", ("hard", None, "auto")) # TODOBB add soft here. @pytest.mark.parametrize("pass_pure_tensor", (True, False)) @pytest.mark.parametrize("fn", [F.clamp_bounding_boxes, transform_cls_to_functional(transforms.ClampBoundingBoxes)]) def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn): @@ -5559,15 +5559,15 @@ def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode, # functional (or to the class) relies on the box's `.clamping_mode` # attribute # - That clamping happens when it should, and only when it should, i.e. - # when the clamping mode is not "none". It doesn't validate the - # nunmerical results, only that clamping happened. For that, we create + # when the clamping mode is not None. It doesn't validate the + # numerical results, only that clamping happened. For that, we create # a large 100x100 box inside of a small 10x10 image. if pass_pure_tensor and fn is not F.clamp_bounding_boxes: # Only the functional supports pure tensors, not the class return - if pass_pure_tensor and clamping_mode is None: - # cannot leave clamping_mode=None when passing pure tensor + if pass_pure_tensor and clamping_mode == "auto": + # cannot leave clamping_mode="auto" when passing pure tensor return if rotated: @@ -5591,8 +5591,8 @@ def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode, else: out = fn(boxes, clamping_mode=clamping_mode) - clamping_mode_prevailing = constructor_clamping_mode if clamping_mode is None else clamping_mode - if clamping_mode_prevailing == "none": + clamping_mode_prevailing = constructor_clamping_mode if clamping_mode == "auto" else clamping_mode + if clamping_mode_prevailing is None: assert_equal(boxes, out) # should be a pass-through else: assert_equal(out, expected_clamped_output) @@ -5600,8 +5600,8 @@ def test_clamping_mode(self, rotated, constructor_clamping_mode, clamping_mode, class TestSetClampingMode: @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) - @pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none")) # TODOBB add soft - @pytest.mark.parametrize("desired_clamping_mode", ("hard", "none")) # TODOBB add soft + @pytest.mark.parametrize("constructor_clamping_mode", ("hard", None)) # TODOBB add soft + @pytest.mark.parametrize("desired_clamping_mode", ("hard", None)) # TODOBB add soft def test_setter(self, format, constructor_clamping_mode, desired_clamping_mode): in_boxes = make_bounding_boxes(format=format, clamping_mode=constructor_clamping_mode) @@ -5611,7 +5611,7 @@ def test_setter(self, format, constructor_clamping_mode, desired_clamping_mode): assert out_boxes.clamping_mode == desired_clamping_mode @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) - @pytest.mark.parametrize("constructor_clamping_mode", ("hard", "none")) # TODOBB add soft + @pytest.mark.parametrize("constructor_clamping_mode", ("hard", None)) # TODOBB add soft def test_pipeline_no_leak(self, format, constructor_clamping_mode): class AssertClampingMode(transforms.Transform): def __init__(self, expected_clamping_mode): @@ -5626,12 +5626,12 @@ def transform(self, inpt, _): t = transforms.Compose( [ - transforms.SetClampingMode("none"), - AssertClampingMode("none"), + transforms.SetClampingMode(None), + AssertClampingMode(None), transforms.SetClampingMode("hard"), AssertClampingMode("hard"), - transforms.SetClampingMode("none"), - AssertClampingMode("none"), + transforms.SetClampingMode(None), + AssertClampingMode(None), transforms.ClampBoundingBoxes("hard"), ] ) @@ -5643,7 +5643,7 @@ def transform(self, inpt, _): # assert that the output boxes clamping_mode is the one set by the last SetClampingMode. # ClampBoundingBoxes doesn't set clamping_mode. - assert out_boxes.clamping_mode == "none" + assert out_boxes.clamping_mode is None class TestClampKeyPoints: diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index 34e44045cbc..2b40f21392f 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Union from torchvision import tv_tensors from torchvision.transforms.v2 import functional as F, Transform @@ -34,7 +34,9 @@ class ClampBoundingBoxes(Transform): """ - def __init__(self, clamping_mode: Optional[CLAMPING_MODE_TYPE] = None) -> None: + # TODOBB consider "auto" to be a Literal, make sur torchscript is still happy + # TODOBB validate clamping_mode + def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None: super().__init__() self.clamping_mode = clamping_mode diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index bca7a6de088..2370fe72fca 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import PIL.Image import torch @@ -376,7 +376,7 @@ def _clamp_bounding_boxes( canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE, ) -> torch.Tensor: - if clamping_mode is not None and clamping_mode == "none": + if clamping_mode is None: return bounding_boxes.clone() # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every # BoundingBoxFormat instead of converting back and forth @@ -479,7 +479,7 @@ def _clamp_y_intercept( b1 = b2.clamp(b1, b3).clamp(0, canvas_size[0]) b4 = b3.clamp(b2, b4).clamp(0, canvas_size[0]) - if clamping_mode == "hard": + if clamping_mode is not None and clamping_mode == "hard": # Hard clamping: Average b1 and b4, and adjust b2 and b3 for maximum area b1 = b4 = (b1 + b4) / 2 @@ -574,7 +574,7 @@ def _clamp_along_y_axis( [case_a, case_b, case_c], ): bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes) - if clamping_mode == "hard": + if clamping_mode is not None and clamping_mode == "hard": bounding_boxes[..., 0].clamp_(0) # Clamp x1 to 0 if need_cast: @@ -610,7 +610,7 @@ def _clamp_rotated_bounding_boxes( Returns: torch.Tensor: Clamped bounding boxes in the original format and shape """ - if clamping_mode is not None and clamping_mode == "none": + if clamping_mode is None: return bounding_boxes.clone() original_shape = bounding_boxes.shape dtype = bounding_boxes.dtype @@ -657,7 +657,7 @@ def clamp_bounding_boxes( inpt: torch.Tensor, format: Optional[BoundingBoxFormat] = None, canvas_size: Optional[tuple[int, int]] = None, - clamping_mode: Optional[CLAMPING_MODE_TYPE] = None, + clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto", ) -> torch.Tensor: """See :func:`~torchvision.transforms.v2.ClampBoundingBoxes` for details.""" if not torch.jit.is_scripting(): @@ -665,7 +665,7 @@ def clamp_bounding_boxes( if torch.jit.is_scripting() or is_pure_tensor(inpt): - if format is None or canvas_size is None or clamping_mode is None: + if format is None or canvas_size is None or (clamping_mode is not None and clamping_mode == "auto"): raise ValueError("For pure tensor inputs, `format`, `canvas_size` and `clamping_mode` have to be passed.") if tv_tensors.is_rotated_bounding_format(format): return _clamp_rotated_bounding_boxes( @@ -676,7 +676,7 @@ def clamp_bounding_boxes( elif isinstance(inpt, tv_tensors.BoundingBoxes): if format is not None or canvas_size is not None: raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.") - if clamping_mode is None: + if clamping_mode is not None and clamping_mode == "auto": clamping_mode = inpt.clamping_mode if tv_tensors.is_rotated_bounding_format(inpt.format): output = _clamp_rotated_bounding_boxes( diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py index 72a2825aad1..4ad6d978bfb 100644 --- a/torchvision/tv_tensors/_bounding_boxes.py +++ b/torchvision/tv_tensors/_bounding_boxes.py @@ -3,7 +3,7 @@ from collections.abc import Mapping, Sequence from enum import Enum -from typing import Any +from typing import Any, Optional import torch from torch.utils._pytree import tree_flatten @@ -48,8 +48,7 @@ def is_rotated_bounding_format(format: BoundingBoxFormat) -> bool: # TODOBB consider making this a Literal instead. Tried briefly and got # torchscript errors, leaving to str for now. -# CLAMPING_MODE_TYPE = Literal["hard", "soft", "none"] -CLAMPING_MODE_TYPE = str +CLAMPING_MODE_TYPE = Optional[str] # TODOBB All docs. Add any new API to rst files, add tutorial[s].