@@ -102,21 +102,23 @@ def is_lossless(self):
102102 """See base class."""
103103 return True
104104
105- def common_asserts_for_test_data(self, data):
105+ def common_asserts_for_test_data(self, data, rtol=1e-6, atol=1e-6 ):
106106 """See base class."""
107107 encoded_x = data.encoded_x[
108108 stages_impl.HadamardEncodingStage.ENCODED_VALUES_KEY]
109- self.assertAllClose(data.x, data.decoded_x)
109+ self.assertAllClose(data.x, data.decoded_x, rtol=rtol, atol=atol )
110110 self.assertLen(encoded_x.shape, 2)
111111 # This is a rotation, hence, the norms should be the same.
112112 # If the input has dimension 1, the transform is applied to the whole input.
113113 # If the input has dimension 2, the transform is applied to every single
114114 # vector separately.
115115 if len(data.x.shape) == 1:
116- self.assertAllClose(np.linalg.norm(data.x), np.linalg.norm(encoded_x))
116+ self.assertAllClose(np.linalg.norm(data.x), np.linalg.norm(encoded_x),
117+ rtol=rtol, atol=atol)
117118 else:
118119 for x, y in zip(data.x, encoded_x):
119- self.assertAllClose(np.linalg.norm(x), np.linalg.norm(y))
120+ self.assertAllClose(np.linalg.norm(x), np.linalg.norm(y),
121+ rtol=rtol, atol=atol)
120122
121123 def test_encoding_randomized(self):
122124 # The encoding stage declares a source of randomness (a random seed) in the
@@ -174,12 +176,15 @@ def get_random_shape_input():
174176 self.assertEqual(test_data.x.shape[0], encoded_shape[0])
175177 self.assertEqual(8, encoded_shape[1])
176178
177- @parameterized.parameters([tf.float32, tf.float64])
178- def test_input_types(self, x_dtype):
179+ @parameterized.parameters([(tf.float16, 1e-3, 1e-3),
180+ (tf.bfloat16, 1e-1, 1e-1),
181+ (tf.float32, 1e-6, 1e-6),
182+ (tf.float64, 1e-6, 1e-6)])
183+ def test_input_types(self, x_dtype, rtol, atol):
179184 test_data = self.run_one_to_many_encode_decode(
180185 self.default_encoding_stage(),
181186 lambda: tf.random.normal([1, 12], dtype=x_dtype))
182- self.common_asserts_for_test_data(test_data)
187+ self.common_asserts_for_test_data(test_data, rtol=rtol, atol=atol )
183188
184189 def test_unknown_shape_raises(self):
185190 x = test_utils.get_tensor_with_random_shape()
0 commit comments