Skip to content

Commit 17b39ee

Browse files
Add encoding type preservation support for float16 and bfloat16.
PiperOrigin-RevId: 515742571
1 parent 009dbdb commit 17b39ee

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/stages_impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,10 +475,10 @@ def encode(self, x, encode_params):
475475
# type to be able to recover the type from encoded_tensors in decode method.
476476
if x.dtype == tf.float32:
477477
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, packed_x)])
478-
elif x.dtype == tf.float64:
478+
elif x.dtype in [tf.float16, tf.bfloat16, tf.float64]:
479479
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, packed_x),
480480
(self.DUMMY_TYPE_VALUES_KEY,
481-
tf.constant(0.0, dtype=tf.float64))])
481+
tf.constant(0.0, dtype=x.dtype))])
482482
else:
483483
raise TypeError(
484484
'Unsupported packing type: %s. Supported types are tf.float32 and '

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/stages_impl_test.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,23 @@ def is_lossless(self):
102102
"""See base class."""
103103
return True
104104

105-
def common_asserts_for_test_data(self, data):
105+
def common_asserts_for_test_data(self, data, rtol=1e-6, atol=1e-6):
106106
"""See base class."""
107107
encoded_x = data.encoded_x[
108108
stages_impl.HadamardEncodingStage.ENCODED_VALUES_KEY]
109-
self.assertAllClose(data.x, data.decoded_x)
109+
self.assertAllClose(data.x, data.decoded_x, rtol=rtol, atol=atol)
110110
self.assertLen(encoded_x.shape, 2)
111111
# This is a rotation, hence, the norms should be the same.
112112
# If the input has dimension 1, the transform is applied to the whole input.
113113
# If the input has dimension 2, the transform is applied to every single
114114
# vector separately.
115115
if len(data.x.shape) == 1:
116-
self.assertAllClose(np.linalg.norm(data.x), np.linalg.norm(encoded_x))
116+
self.assertAllClose(np.linalg.norm(data.x), np.linalg.norm(encoded_x),
117+
rtol=rtol, atol=atol)
117118
else:
118119
for x, y in zip(data.x, encoded_x):
119-
self.assertAllClose(np.linalg.norm(x), np.linalg.norm(y))
120+
self.assertAllClose(np.linalg.norm(x), np.linalg.norm(y),
121+
rtol=rtol, atol=atol)
120122

121123
def test_encoding_randomized(self):
122124
# The encoding stage declares a source of randomness (a random seed) in the
@@ -174,12 +176,15 @@ def get_random_shape_input():
174176
self.assertEqual(test_data.x.shape[0], encoded_shape[0])
175177
self.assertEqual(8, encoded_shape[1])
176178

177-
@parameterized.parameters([tf.float32, tf.float64])
178-
def test_input_types(self, x_dtype):
179+
@parameterized.parameters([(tf.float16, 1e-3, 1e-3),
180+
(tf.bfloat16, 1e-1, 1e-1),
181+
(tf.float32, 1e-6, 1e-6),
182+
(tf.float64, 1e-6, 1e-6)])
183+
def test_input_types(self, x_dtype, rtol, atol):
179184
test_data = self.run_one_to_many_encode_decode(
180185
self.default_encoding_stage(),
181186
lambda: tf.random.normal([1, 12], dtype=x_dtype))
182-
self.common_asserts_for_test_data(test_data)
187+
self.common_asserts_for_test_data(test_data, rtol=rtol, atol=atol)
183188

184189
def test_unknown_shape_raises(self):
185190
x = test_utils.get_tensor_with_random_shape()

0 commit comments

Comments
 (0)