Skip to content

Commit f1b4c7a

Browse files
vfdev-5pmeierNicolasHug
authored
Fixed sigma input type for v2.GaussianBlur (#7887)
Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent a2f8f8e commit f1b4c7a

File tree

5 files changed

+67
-65
lines changed

5 files changed

+67
-65
lines changed

test/test_transforms_v2.py

Lines changed: 4 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -449,37 +449,6 @@ def test__get_params(self, fill, side_range):
449449
assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h
450450

451451

452-
class TestGaussianBlur:
453-
def test_assertions(self):
454-
with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
455-
transforms.GaussianBlur([10, 12, 14])
456-
457-
with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
458-
transforms.GaussianBlur(4)
459-
460-
with pytest.raises(
461-
TypeError, match="sigma should be a single int or float or a list/tuple with length 2 floats."
462-
):
463-
transforms.GaussianBlur(3, sigma=[1, 2, 3])
464-
465-
with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"):
466-
transforms.GaussianBlur(3, sigma=-1.0)
467-
468-
with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
469-
transforms.GaussianBlur(3, sigma=[2.0, 1.0])
470-
471-
@pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]])
472-
def test__get_params(self, sigma):
473-
transform = transforms.GaussianBlur(3, sigma=sigma)
474-
params = transform._get_params([])
475-
476-
if isinstance(sigma, float):
477-
assert params["sigma"][0] == params["sigma"][1] == 10
478-
else:
479-
assert sigma[0] <= params["sigma"][0] <= sigma[1]
480-
assert sigma[0] <= params["sigma"][1] <= sigma[1]
481-
482-
483452
class TestRandomPerspective:
484453
def test_assertions(self):
485454
with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"):
@@ -503,24 +472,18 @@ def test__get_params(self):
503472
class TestElasticTransform:
504473
def test_assertions(self):
505474

506-
with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"):
475+
with pytest.raises(TypeError, match="alpha should be a number or a sequence of numbers"):
507476
transforms.ElasticTransform({})
508477

509-
with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"):
478+
with pytest.raises(ValueError, match="alpha is a sequence its length should be 1 or 2"):
510479
transforms.ElasticTransform([1.0, 2.0, 3.0])
511480

512-
with pytest.raises(ValueError, match="alpha should be a sequence of floats"):
513-
transforms.ElasticTransform([1, 2])
514-
515-
with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"):
481+
with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
516482
transforms.ElasticTransform(1.0, {})
517483

518-
with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"):
484+
with pytest.raises(ValueError, match="sigma is a sequence its length should be 1 or 2"):
519485
transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0])
520486

521-
with pytest.raises(ValueError, match="sigma should be a sequence of floats"):
522-
transforms.ElasticTransform(1.0, [1, 2])
523-
524487
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
525488
transforms.ElasticTransform(1.0, 2.0, fill="abc")
526489

test/test_transforms_v2_refactored.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2859,3 +2859,46 @@ def test_transform_passthrough(self, make_input):
28592859
_, output = transform(make_image(self.INPUT_SIZE), input)
28602860

28612861
assert output is input
2862+
2863+
2864+
class TestGaussianBlur:
2865+
@pytest.mark.parametrize(
2866+
"make_input",
2867+
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
2868+
)
2869+
@pytest.mark.parametrize("device", cpu_and_cuda())
2870+
@pytest.mark.parametrize("sigma", [5, (0.5, 2)])
2871+
def test_transform(self, make_input, device, sigma):
2872+
check_transform(transforms.GaussianBlur(kernel_size=3, sigma=sigma), make_input(device=device))
2873+
2874+
def test_assertions(self):
2875+
with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
2876+
transforms.GaussianBlur([10, 12, 14])
2877+
2878+
with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
2879+
transforms.GaussianBlur(4)
2880+
2881+
with pytest.raises(ValueError, match="If sigma is a sequence its length should be 1 or 2. Got 3"):
2882+
transforms.GaussianBlur(3, sigma=[1, 2, 3])
2883+
2884+
with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
2885+
transforms.GaussianBlur(3, sigma=-1.0)
2886+
2887+
with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
2888+
transforms.GaussianBlur(3, sigma=[2.0, 1.0])
2889+
2890+
with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
2891+
transforms.GaussianBlur(3, sigma={})
2892+
2893+
@pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]])
2894+
def test__get_params(self, sigma):
2895+
transform = transforms.GaussianBlur(3, sigma=sigma)
2896+
params = transform._get_params([])
2897+
2898+
if isinstance(sigma, float):
2899+
assert params["sigma"][0] == params["sigma"][1] == sigma
2900+
elif isinstance(sigma, list) and len(sigma) == 1:
2901+
assert params["sigma"][0] == params["sigma"][1] == sigma[0]
2902+
else:
2903+
assert sigma[0] <= params["sigma"][0] <= sigma[1]
2904+
assert sigma[0] <= params["sigma"][1] <= sigma[1]

torchvision/transforms/v2/_geometry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
_get_fill,
2222
_setup_angle,
2323
_setup_fill_arg,
24-
_setup_float_or_seq,
24+
_setup_number_or_seq,
2525
_setup_size,
2626
get_bounding_boxes,
2727
has_all,
@@ -1060,8 +1060,8 @@ def __init__(
10601060
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
10611061
) -> None:
10621062
super().__init__()
1063-
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
1064-
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
1063+
self.alpha = _setup_number_or_seq(alpha, "alpha")
1064+
self.sigma = _setup_number_or_seq(sigma, "sigma")
10651065

10661066
self.interpolation = _check_interpolation(interpolation)
10671067
self.fill = fill

torchvision/transforms/v2/_misc.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torchvision import transforms as _transforms, tv_tensors
1010
from torchvision.transforms.v2 import functional as F, Transform
1111

12-
from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
12+
from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
1313

1414

1515
# TODO: do we want/need to expose this?
@@ -198,17 +198,10 @@ def __init__(
198198
if ks <= 0 or ks % 2 == 0:
199199
raise ValueError("Kernel size value should be an odd and positive number.")
200200

201-
if isinstance(sigma, (int, float)):
202-
if sigma <= 0:
203-
raise ValueError("If sigma is a single number, it must be positive.")
204-
sigma = float(sigma)
205-
elif isinstance(sigma, Sequence) and len(sigma) == 2:
206-
if not 0.0 < sigma[0] <= sigma[1]:
207-
raise ValueError("sigma values should be positive and of the form (min, max).")
208-
else:
209-
raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
201+
self.sigma = _setup_number_or_seq(sigma, "sigma")
210202

211-
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
203+
if not 0.0 < self.sigma[0] <= self.sigma[1]:
204+
raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}")
212205

213206
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
214207
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()

torchvision/transforms/v2/_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,23 @@
1818
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
1919

2020

21-
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
22-
if not isinstance(arg, (float, Sequence)):
23-
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
24-
if isinstance(arg, Sequence) and len(arg) != req_size:
25-
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
21+
def _setup_number_or_seq(arg: Union[int, float, Sequence[Union[int, float]]], name: str) -> Sequence[float]:
22+
if not isinstance(arg, (int, float, Sequence)):
23+
raise TypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}")
24+
if isinstance(arg, Sequence) and len(arg) not in (1, 2):
25+
raise ValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}")
2626
if isinstance(arg, Sequence):
2727
for element in arg:
28-
if not isinstance(element, float):
29-
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
28+
if not isinstance(element, (int, float)):
29+
raise ValueError(f"{name} should be a sequence of numbers. Got {type(element)}")
3030

31-
if isinstance(arg, float):
31+
if isinstance(arg, (int, float)):
3232
arg = [float(arg), float(arg)]
33-
if isinstance(arg, (list, tuple)) and len(arg) == 1:
34-
arg = [arg[0], arg[0]]
33+
elif isinstance(arg, Sequence):
34+
if len(arg) == 1:
35+
arg = [float(arg[0]), float(arg[0])]
36+
else:
37+
arg = [float(arg[0]), float(arg[1])]
3538
return arg
3639

3740

0 commit comments

Comments
 (0)