@@ -3978,7 +3978,7 @@ class TestGaussianNoise:
39783978 "make_input" ,
39793979 [make_image_tensor , make_image , make_video ],
39803980 )
3981- def test_kernel (self , make_input ):
3981+ def test_kernel_float (self , make_input ):
39823982 check_kernel (
39833983 F .gaussian_noise ,
39843984 make_input (dtype = torch .float32 ),
@@ -3990,9 +3990,28 @@ def test_kernel(self, make_input):
39903990 "make_input" ,
39913991 [make_image_tensor , make_image , make_video ],
39923992 )
3993- def test_functional (self , make_input ):
3993+ def test_kernel_uint8 (self , make_input ):
3994+ check_kernel (
3995+ F .gaussian_noise ,
3996+ make_input (dtype = torch .uint8 ),
3997+ # This cannot pass because the noise on a batch in not per-image
3998+ check_batched_vs_unbatched = False ,
3999+ )
4000+
4001+ @pytest .mark .parametrize (
4002+ "make_input" ,
4003+ [make_image_tensor , make_image , make_video ],
4004+ )
4005+ def test_functional_float (self , make_input ):
39944006 check_functional (F .gaussian_noise , make_input (dtype = torch .float32 ))
39954007
4008+ @pytest .mark .parametrize (
4009+ "make_input" ,
4010+ [make_image_tensor , make_image , make_video ],
4011+ )
4012+ def test_functional_uint8 (self , make_input ):
4013+ check_functional (F .gaussian_noise , make_input (dtype = torch .uint8 ))
4014+
39964015 @pytest .mark .parametrize (
39974016 ("kernel" , "input_type" ),
39984017 [
@@ -4008,10 +4027,11 @@ def test_functional_signature(self, kernel, input_type):
40084027 "make_input" ,
40094028 [make_image_tensor , make_image , make_video ],
40104029 )
4011- def test_transform (self , make_input ):
4030+ def test_transform_float (self , make_input ):
40124031 def adapter (_ , input , __ ):
4013- # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
4014- # Same for PIL images
4032+ # We have two different implementations for floats and uint8
4033+ # To test this implementation we'll convert the auto-generated uint8 tensors to float32
4034+ # We don't support other int dtypes nor pil images
40154035 for key , value in input .items ():
40164036 if isinstance (value , torch .Tensor ) and not value .is_floating_point ():
40174037 input [key ] = value .to (torch .float32 )
@@ -4021,11 +4041,29 @@ def adapter(_, input, __):
40214041
40224042 check_transform (transforms .GaussianNoise (), make_input (dtype = torch .float32 ), check_sample_input = adapter )
40234043
4044+ @pytest .mark .parametrize (
4045+ "make_input" ,
4046+ [make_image_tensor , make_image , make_video ],
4047+ )
4048+ def test_transform_uint8 (self , make_input ):
4049+ def adapter (_ , input , __ ):
4050+ # We have two different implementations for floats and uint8
4051+ # To test this implementation we'll convert every tensor to uint8
4052+ # We don't support other int dtypes nor pil images
4053+ for key , value in input .items ():
4054+ if isinstance (value , torch .Tensor ) and not value .dtype != torch .uint8 :
4055+ input [key ] = value .to (torch .uint8 )
4056+ if isinstance (value , PIL .Image .Image ):
4057+ input [key ] = F .pil_to_tensor (value ).to (torch .uint8 )
4058+ return input
4059+
4060+ check_transform (transforms .GaussianNoise (), make_input (dtype = torch .uint8 ), check_sample_input = adapter )
4061+
40244062 def test_bad_input (self ):
40254063 with pytest .raises (ValueError , match = "Gaussian Noise is not implemented for PIL images." ):
40264064 F .gaussian_noise (make_image_pil ())
4027- with pytest .raises (ValueError , match = "Input tensor is expected to be in float dtype" ):
4028- F .gaussian_noise (make_image (dtype = torch .uint8 ))
4065+ with pytest .raises (ValueError , match = "Input tensor is expected to be in uint8 or float dtype" ):
4066+ F .gaussian_noise (make_image (dtype = torch .int32 ))
40294067 with pytest .raises (ValueError , match = "sigma shouldn't be negative" ):
40304068 F .gaussian_noise (make_image (dtype = torch .float32 ), sigma = - 1 )
40314069
0 commit comments