@@ -3993,7 +3993,7 @@ class TestGaussianNoise:
39933993        "make_input" , 
39943994        [make_image_tensor , make_image , make_video ], 
39953995    ) 
3996-     def  test_kernel (self , make_input ):
3996+     def  test_kernel_float (self , make_input ):
39973997        check_kernel (
39983998            F .gaussian_noise ,
39993999            make_input (dtype = torch .float32 ),
@@ -4005,9 +4005,28 @@ def test_kernel(self, make_input):
40054005        "make_input" , 
40064006        [make_image_tensor , make_image , make_video ], 
40074007    ) 
4008-     def  test_functional (self , make_input ):
4008+     def  test_kernel_uint8 (self , make_input ):
4009+         check_kernel (
4010+             F .gaussian_noise ,
4011+             make_input (dtype = torch .uint8 ),
4012+             # This cannot pass because the noise on a batch in not per-image 
4013+             check_batched_vs_unbatched = False ,
4014+         )
4015+ 
4016+     @pytest .mark .parametrize ( 
4017+         "make_input" , 
4018+         [make_image_tensor , make_image , make_video ], 
4019+     ) 
4020+     def  test_functional_float (self , make_input ):
40094021        check_functional (F .gaussian_noise , make_input (dtype = torch .float32 ))
40104022
4023+     @pytest .mark .parametrize ( 
4024+         "make_input" , 
4025+         [make_image_tensor , make_image , make_video ], 
4026+     ) 
4027+     def  test_functional_uint8 (self , make_input ):
4028+         check_functional (F .gaussian_noise , make_input (dtype = torch .uint8 ))
4029+ 
40114030    @pytest .mark .parametrize ( 
40124031        ("kernel" , "input_type" ), 
40134032        [ 
@@ -4023,10 +4042,11 @@ def test_functional_signature(self, kernel, input_type):
40234042        "make_input" , 
40244043        [make_image_tensor , make_image , make_video ], 
40254044    ) 
4026-     def  test_transform (self , make_input ):
4045+     def  test_transform_float (self , make_input ):
40274046        def  adapter (_ , input , __ ):
4028-             # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32 
4029-             # Same for PIL images 
4047+             # We have two different implementations for floats and uint8 
4048+             # To test this implementation we'll convert the auto-generated uint8 tensors to float32 
4049+             # We don't support other int dtypes nor pil images 
40304050            for  key , value  in  input .items ():
40314051                if  isinstance (value , torch .Tensor ) and  not  value .is_floating_point ():
40324052                    input [key ] =  value .to (torch .float32 )
@@ -4036,11 +4056,29 @@ def adapter(_, input, __):
40364056
40374057        check_transform (transforms .GaussianNoise (), make_input (dtype = torch .float32 ), check_sample_input = adapter )
40384058
4059+     @pytest .mark .parametrize ( 
4060+         "make_input" , 
4061+         [make_image_tensor , make_image , make_video ], 
4062+     ) 
4063+     def  test_transform_uint8 (self , make_input ):
4064+         def  adapter (_ , input , __ ):
4065+             # We have two different implementations for floats and uint8 
4066+             # To test this implementation we'll convert every tensor to uint8 
4067+             # We don't support other int dtypes nor pil images 
4068+             for  key , value  in  input .items ():
4069+                 if  isinstance (value , torch .Tensor ) and  not  value .dtype  !=  torch .uint8 :
4070+                     input [key ] =  value .to (torch .uint8 )
4071+                 if  isinstance (value , PIL .Image .Image ):
4072+                     input [key ] =  F .pil_to_tensor (value ).to (torch .uint8 )
4073+             return  input 
4074+ 
4075+         check_transform (transforms .GaussianNoise (), make_input (dtype = torch .uint8 ), check_sample_input = adapter )
4076+ 
40394077    def  test_bad_input (self ):
40404078        with  pytest .raises (ValueError , match = "Gaussian Noise is not implemented for PIL images." ):
40414079            F .gaussian_noise (make_image_pil ())
4042-         with  pytest .raises (ValueError , match = "Input tensor is expected to be in float dtype" ):
4043-             F .gaussian_noise (make_image (dtype = torch .uint8 ))
4080+         with  pytest .raises (ValueError , match = "Input tensor is expected to be in uint8 or  float dtype" ):
4081+             F .gaussian_noise (make_image (dtype = torch .int32 ))
40444082        with  pytest .raises (ValueError , match = "sigma shouldn't be negative" ):
40454083            F .gaussian_noise (make_image (dtype = torch .float32 ), sigma = - 1 )
40464084
0 commit comments