@@ -102,21 +102,23 @@ def is_lossless(self):
102
102
"""See base class."""
103
103
return True
104
104
105
- def common_asserts_for_test_data (self , data ):
105
+ def common_asserts_for_test_data (self , data , rtol = 1e-6 , atol = 1e-6 ):
106
106
"""See base class."""
107
107
encoded_x = data .encoded_x [
108
108
stages_impl .HadamardEncodingStage .ENCODED_VALUES_KEY ]
109
- self .assertAllClose (data .x , data .decoded_x )
109
+ self .assertAllClose (data .x , data .decoded_x , rtol = rtol , atol = atol )
110
110
self .assertLen (encoded_x .shape , 2 )
111
111
# This is a rotation, hence, the norms should be the same.
112
112
# If the input has dimension 1, the transform is applied to the whole input.
113
113
# If the input has dimension 2, the transform is applied to every single
114
114
# vector separately.
115
115
if len (data .x .shape ) == 1 :
116
- self .assertAllClose (np .linalg .norm (data .x ), np .linalg .norm (encoded_x ))
116
+ self .assertAllClose (np .linalg .norm (data .x ), np .linalg .norm (encoded_x ),
117
+ rtol = rtol , atol = atol )
117
118
else :
118
119
for x , y in zip (data .x , encoded_x ):
119
- self .assertAllClose (np .linalg .norm (x ), np .linalg .norm (y ))
120
+ self .assertAllClose (np .linalg .norm (x ), np .linalg .norm (y ),
121
+ rtol = rtol , atol = atol )
120
122
121
123
def test_encoding_randomized (self ):
122
124
# The encoding stage declares a source of randomness (a random seed) in the
@@ -174,12 +176,15 @@ def get_random_shape_input():
174
176
self .assertEqual (test_data .x .shape [0 ], encoded_shape [0 ])
175
177
self .assertEqual (8 , encoded_shape [1 ])
176
178
177
- @parameterized .parameters ([tf .float32 , tf .float64 ])
178
- def test_input_types (self , x_dtype ):
179
+ @parameterized .parameters ([(tf .float16 , 1e-3 , 1e-3 ),
180
+ (tf .bfloat16 , 1e-1 , 1e-1 ),
181
+ (tf .float32 , 1e-6 , 1e-6 ),
182
+ (tf .float64 , 1e-6 , 1e-6 )])
183
+ def test_input_types (self , x_dtype , rtol , atol ):
179
184
test_data = self .run_one_to_many_encode_decode (
180
185
self .default_encoding_stage (),
181
186
lambda : tf .random .normal ([1 , 12 ], dtype = x_dtype ))
182
- self .common_asserts_for_test_data (test_data )
187
+ self .common_asserts_for_test_data (test_data , rtol = rtol , atol = atol )
183
188
184
189
def test_unknown_shape_raises (self ):
185
190
x = test_utils .get_tensor_with_random_shape ()
0 commit comments