Skip to content

Commit 2535edf

Browse files
sergiussergius
authored andcommitted
Implemented GaussianNoise compatibility for uint8 inputs
1 parent 98f8b37 commit 2535edf

File tree

3 files changed

+75
-17
lines changed

3 files changed

+75
-17
lines changed

test/test_transforms_v2.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3978,7 +3978,7 @@ class TestGaussianNoise:
39783978
"make_input",
39793979
[make_image_tensor, make_image, make_video],
39803980
)
3981-
def test_kernel(self, make_input):
3981+
def test_kernel_float(self, make_input):
39823982
check_kernel(
39833983
F.gaussian_noise,
39843984
make_input(dtype=torch.float32),
@@ -3990,9 +3990,28 @@ def test_kernel(self, make_input):
39903990
"make_input",
39913991
[make_image_tensor, make_image, make_video],
39923992
)
3993-
def test_functional(self, make_input):
3993+
def test_kernel_uint8(self, make_input):
3994+
check_kernel(
3995+
F.gaussian_noise,
3996+
make_input(dtype=torch.uint8),
3997+
# This cannot pass because the noise on a batch in not per-image
3998+
check_batched_vs_unbatched=False,
3999+
)
4000+
4001+
@pytest.mark.parametrize(
4002+
"make_input",
4003+
[make_image_tensor, make_image, make_video],
4004+
)
4005+
def test_functional_float(self, make_input):
39944006
check_functional(F.gaussian_noise, make_input(dtype=torch.float32))
39954007

4008+
@pytest.mark.parametrize(
4009+
"make_input",
4010+
[make_image_tensor, make_image, make_video],
4011+
)
4012+
def test_functional_uint8(self, make_input):
4013+
check_functional(F.gaussian_noise, make_input(dtype=torch.uint8))
4014+
39964015
@pytest.mark.parametrize(
39974016
("kernel", "input_type"),
39984017
[
@@ -4008,10 +4027,11 @@ def test_functional_signature(self, kernel, input_type):
40084027
"make_input",
40094028
[make_image_tensor, make_image, make_video],
40104029
)
4011-
def test_transform(self, make_input):
4030+
def test_transform_float(self, make_input):
40124031
def adapter(_, input, __):
4013-
# This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
4014-
# Same for PIL images
4032+
# We have two different implementations for floats and uint8
4033+
# To test this implementation we'll convert the auto-generated uint8 tensors to float32
4034+
# We don't support other int dtypes nor pil images
40154035
for key, value in input.items():
40164036
if isinstance(value, torch.Tensor) and not value.is_floating_point():
40174037
input[key] = value.to(torch.float32)
@@ -4021,11 +4041,29 @@ def adapter(_, input, __):
40214041

40224042
check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter)
40234043

4044+
@pytest.mark.parametrize(
4045+
"make_input",
4046+
[make_image_tensor, make_image, make_video],
4047+
)
4048+
def test_transform_uint8(self, make_input):
4049+
def adapter(_, input, __):
4050+
# We have two different implementations for floats and uint8
4051+
# To test this implementation we'll convert every tensor to uint8
4052+
# We don't support other int dtypes nor pil images
4053+
for key, value in input.items():
4054+
if isinstance(value, torch.Tensor) and not value.dtype != torch.uint8:
4055+
input[key] = value.to(torch.uint8)
4056+
if isinstance(value, PIL.Image.Image):
4057+
input[key] = F.pil_to_tensor(value).to(torch.uint8)
4058+
return input
4059+
4060+
check_transform(transforms.GaussianNoise(), make_input(dtype=torch.uint8), check_sample_input=adapter)
4061+
40244062
def test_bad_input(self):
40254063
with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
40264064
F.gaussian_noise(make_image_pil())
4027-
with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
4028-
F.gaussian_noise(make_image(dtype=torch.uint8))
4065+
with pytest.raises(ValueError, match="Input tensor is expected to be in uint8 or float dtype"):
4066+
F.gaussian_noise(make_image(dtype=torch.int32))
40294067
with pytest.raises(ValueError, match="sigma shouldn't be negative"):
40304068
F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)
40314069

torchvision/transforms/v2/_misc.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,22 @@ class GaussianNoise(Transform):
214214
Each image or frame in a batch will be transformed independently i.e. the
215215
noise added to each image will be different.
216216
217-
The input tensor is also expected to be of float dtype in ``[0, 1]``.
218-
This transform does not support PIL images.
217+
The input tensor is also expected to be of float dtype in ``[0, 1]``,
218+
or of ``uint8`` dtype in ``[0, 255]``. This transform does not support PIL
219+
images.
220+
221+
Regardless of the dtype used, the parameters of the function use the same
222+
scale, so a ``mean`` parameter of 0.5 will result in an average value
223+
increase of 0.5 units for float images, and an average increase of 127.5
224+
units for ``uint8`` images.
219225
220226
Args:
221227
mean (float): Mean of the sampled normal distribution. Default is 0.
222228
sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1.
223-
clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True.
229+
clip (bool, optional): Whether to clip the values after adding noise, be it to
230+
``[0, 1]`` for floats or to ``[0, 255]`` for ``uint8``. Setting this parameter to
231+
``False`` may cause unsigned integer overflows with uint8 inputs.
232+
Default is True.
224233
"""
225234

226235
def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None:

torchvision/transforms/v2/functional/_misc.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,27 @@ def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, cl
195195
@_register_kernel_internal(gaussian_noise, torch.Tensor)
196196
@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
197197
def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
198-
if not image.is_floating_point():
199-
raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")
200198
if sigma < 0:
201199
raise ValueError(f"sigma shouldn't be negative. Got {sigma}")
202200

203-
noise = mean + torch.randn_like(image) * sigma
204-
out = image + noise
205-
if clip:
206-
out = torch.clamp(out, 0, 1)
207-
return out
201+
if image.is_floating_point():
202+
noise = mean + torch.randn_like(image) * sigma
203+
out = image + noise
204+
if clip:
205+
out = torch.clamp(out, 0, 1)
206+
return out
207+
208+
elif image.dtype == torch.uint8:
209+
# Convert to intermediate dtype int16 to add to input more efficiently
210+
noise = ((mean * 255) + torch.randn_like(image, dtype=torch.float32) * (sigma * 255)).to(torch.int16)
211+
out = image + noise
212+
213+
if clip:
214+
out = torch.clamp(out, 0, 255)
215+
return out.to(torch.uint8)
216+
217+
else:
218+
raise ValueError(f"Input tensor is expected to be in uint8 or float dtype, got dtype={image.dtype}")
208219

209220

210221
@_register_kernel_internal(gaussian_noise, tv_tensors.Video)

0 commit comments

Comments
 (0)