@@ -3993,7 +3993,7 @@ class TestGaussianNoise:
39933993 "make_input" ,
39943994 [make_image_tensor , make_image , make_video ],
39953995 )
3996- def test_kernel (self , make_input ):
3996+ def test_kernel_float (self , make_input ):
39973997 check_kernel (
39983998 F .gaussian_noise ,
39993999 make_input (dtype = torch .float32 ),
@@ -4005,9 +4005,28 @@ def test_kernel(self, make_input):
40054005 "make_input" ,
40064006 [make_image_tensor , make_image , make_video ],
40074007 )
4008- def test_functional (self , make_input ):
4008+ def test_kernel_uint8 (self , make_input ):
4009+ check_kernel (
4010+ F .gaussian_noise ,
4011+ make_input (dtype = torch .uint8 ),
4012+ # This cannot pass because the noise on a batch in not per-image
4013+ check_batched_vs_unbatched = False ,
4014+ )
4015+
4016+ @pytest .mark .parametrize (
4017+ "make_input" ,
4018+ [make_image_tensor , make_image , make_video ],
4019+ )
4020+ def test_functional_float (self , make_input ):
40094021 check_functional (F .gaussian_noise , make_input (dtype = torch .float32 ))
40104022
4023+ @pytest .mark .parametrize (
4024+ "make_input" ,
4025+ [make_image_tensor , make_image , make_video ],
4026+ )
4027+ def test_functional_uint8 (self , make_input ):
4028+ check_functional (F .gaussian_noise , make_input (dtype = torch .uint8 ))
4029+
40114030 @pytest .mark .parametrize (
40124031 ("kernel" , "input_type" ),
40134032 [
@@ -4023,10 +4042,11 @@ def test_functional_signature(self, kernel, input_type):
40234042 "make_input" ,
40244043 [make_image_tensor , make_image , make_video ],
40254044 )
4026- def test_transform (self , make_input ):
4045+ def test_transform_float (self , make_input ):
40274046 def adapter (_ , input , __ ):
4028- # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
4029- # Same for PIL images
4047+ # We have two different implementations for floats and uint8
4048+ # To test this implementation we'll convert the auto-generated uint8 tensors to float32
4049+ # We don't support other int dtypes nor pil images
40304050 for key , value in input .items ():
40314051 if isinstance (value , torch .Tensor ) and not value .is_floating_point ():
40324052 input [key ] = value .to (torch .float32 )
@@ -4036,11 +4056,29 @@ def adapter(_, input, __):
40364056
40374057 check_transform (transforms .GaussianNoise (), make_input (dtype = torch .float32 ), check_sample_input = adapter )
40384058
4059+ @pytest .mark .parametrize (
4060+ "make_input" ,
4061+ [make_image_tensor , make_image , make_video ],
4062+ )
4063+ def test_transform_uint8 (self , make_input ):
4064+ def adapter (_ , input , __ ):
4065+ # We have two different implementations for floats and uint8
4066+ # To test this implementation we'll convert every tensor to uint8
4067+ # We don't support other int dtypes nor pil images
4068+ for key , value in input .items ():
4069+ if isinstance (value , torch .Tensor ) and not value .dtype != torch .uint8 :
4070+ input [key ] = value .to (torch .uint8 )
4071+ if isinstance (value , PIL .Image .Image ):
4072+ input [key ] = F .pil_to_tensor (value ).to (torch .uint8 )
4073+ return input
4074+
4075+ check_transform (transforms .GaussianNoise (), make_input (dtype = torch .uint8 ), check_sample_input = adapter )
4076+
40394077 def test_bad_input (self ):
40404078 with pytest .raises (ValueError , match = "Gaussian Noise is not implemented for PIL images." ):
40414079 F .gaussian_noise (make_image_pil ())
4042- with pytest .raises (ValueError , match = "Input tensor is expected to be in float dtype" ):
4043- F .gaussian_noise (make_image (dtype = torch .uint8 ))
4080+ with pytest .raises (ValueError , match = "Input tensor is expected to be in uint8 or float dtype" ):
4081+ F .gaussian_noise (make_image (dtype = torch .int32 ))
40444082 with pytest .raises (ValueError , match = "sigma shouldn't be negative" ):
40454083 F .gaussian_noise (make_image (dtype = torch .float32 ), sigma = - 1 )
40464084
0 commit comments