diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 06d6514770d..751806aaef4 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3978,7 +3978,7 @@ class TestGaussianNoise: "make_input", [make_image_tensor, make_image, make_video], ) - def test_kernel(self, make_input): + def test_kernel_float(self, make_input): check_kernel( F.gaussian_noise, make_input(dtype=torch.float32), @@ -3990,9 +3990,28 @@ def test_kernel(self, make_input): "make_input", [make_image_tensor, make_image, make_video], ) - def test_functional(self, make_input): + def test_kernel_uint8(self, make_input): + check_kernel( + F.gaussian_noise, + make_input(dtype=torch.uint8), + # This cannot pass because the noise on a batch in not per-image + check_batched_vs_unbatched=False, + ) + + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image, make_video], + ) + def test_functional_float(self, make_input): check_functional(F.gaussian_noise, make_input(dtype=torch.float32)) + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image, make_video], + ) + def test_functional_uint8(self, make_input): + check_functional(F.gaussian_noise, make_input(dtype=torch.uint8)) + @pytest.mark.parametrize( ("kernel", "input_type"), [ @@ -4008,10 +4027,11 @@ def test_functional_signature(self, kernel, input_type): "make_input", [make_image_tensor, make_image, make_video], ) - def test_transform(self, make_input): + def test_transform_float(self, make_input): def adapter(_, input, __): - # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32 - # Same for PIL images + # We have two different implementations for floats and uint8 + # To test this implementation we'll convert the auto-generated uint8 tensors to float32 + # We don't support other int dtypes nor pil images for key, value in input.items(): if isinstance(value, torch.Tensor) and not value.is_floating_point(): input[key] = value.to(torch.float32) @@ -4021,11 +4041,29 @@ def adapter(_, input, __): check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter) + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image, make_video], + ) + def test_transform_uint8(self, make_input): + def adapter(_, input, __): + # We have two different implementations for floats and uint8 + # To test this implementation we'll convert every tensor to uint8 + # We don't support other int dtypes nor pil images + for key, value in input.items(): + if isinstance(value, torch.Tensor) and not value.dtype != torch.uint8: + input[key] = value.to(torch.uint8) + if isinstance(value, PIL.Image.Image): + input[key] = F.pil_to_tensor(value).to(torch.uint8) + return input + + check_transform(transforms.GaussianNoise(), make_input(dtype=torch.uint8), check_sample_input=adapter) + def test_bad_input(self): with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."): F.gaussian_noise(make_image_pil()) - with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"): - F.gaussian_noise(make_image(dtype=torch.uint8)) + with pytest.raises(ValueError, match="Input tensor is expected to be in uint8 or float dtype"): + F.gaussian_noise(make_image(dtype=torch.int32)) with pytest.raises(ValueError, match="sigma shouldn't be negative"): F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index dfd521b13be..875f65d581c 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -214,13 +214,22 @@ class GaussianNoise(Transform): Each image or frame in a batch will be transformed independently i.e. the noise added to each image will be different. - The input tensor is also expected to be of float dtype in ``[0, 1]``. - This transform does not support PIL images. + The input tensor is also expected to be of float dtype in ``[0, 1]``, + or of ``uint8`` dtype in ``[0, 255]``. This transform does not support PIL + images. + + Regardless of the dtype used, the parameters of the function use the same + scale, so a ``mean`` parameter of 0.5 will result in an average value + increase of 0.5 units for float images, and an average increase of 127.5 + units for ``uint8`` images. Args: mean (float): Mean of the sampled normal distribution. Default is 0. sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1. - clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True. + clip (bool, optional): Whether to clip the values after adding noise, be it to + ``[0, 1]`` for floats or to ``[0, 255]`` for ``uint8``. Setting this parameter to + ``False`` may cause unsigned integer overflows with uint8 inputs. + Default is True. """ def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 797fe496afb..7ae3d45c658 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -195,16 +195,27 @@ def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, cl @_register_kernel_internal(gaussian_noise, torch.Tensor) @_register_kernel_internal(gaussian_noise, tv_tensors.Image) def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor: - if not image.is_floating_point(): - raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}") if sigma < 0: raise ValueError(f"sigma shouldn't be negative. Got {sigma}") - noise = mean + torch.randn_like(image) * sigma - out = image + noise - if clip: - out = torch.clamp(out, 0, 1) - return out + if image.is_floating_point(): + noise = mean + torch.randn_like(image) * sigma + out = image + noise + if clip: + out = torch.clamp(out, 0, 1) + return out + + elif image.dtype == torch.uint8: + # Convert to intermediate dtype int16 to add to input more efficiently + noise = ((mean * 255) + torch.randn_like(image, dtype=torch.float32) * (sigma * 255)).to(torch.int16) + out = image + noise + + if clip: + out = torch.clamp(out, 0, 255) + return out.to(torch.uint8) + + else: + raise ValueError(f"Input tensor is expected to be in uint8 or float dtype, got dtype={image.dtype}") @_register_kernel_internal(gaussian_noise, tv_tensors.Video)