Skip to content

Commit 9918bee

Browse files
ssjhvcopybara-github
authored andcommitted
Fixed mixed precision setting in test.
PiperOrigin-RevId: 592423175 Change-Id: I83bee20c5cda13beab8d1a9191d0f1bf7d178bf3
1 parent 3a43109 commit 9918bee

File tree

5 files changed

+5
-5
lines changed

5 files changed

+5
-5
lines changed

tensorflow_compression/python/entropy_models/continuous_batched_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def test_dtypes_are_correct_with_mixed_precision(self):
215215
self.assertEqual(bits.shape, (2,))
216216
self.assertAllGreaterEqual(bits, 0.)
217217
finally:
218-
tf.keras.mixed_precision.set_global_policy(None)
218+
tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx())
219219

220220
def test_small_cdfs_for_dirac_prior_without_quantization_offset(self):
221221
prior = uniform_noise.NoisyNormal(loc=100. * tf.range(16.), scale=1e-10)

tensorflow_compression/python/entropy_models/continuous_indexed_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def test_dtypes_are_correct_with_mixed_precision(self):
220220
self.assertEqual(bits.shape, (2,))
221221
self.assertAllGreaterEqual(bits, 0.)
222222
finally:
223-
tf.keras.mixed_precision.set_global_policy(None)
223+
tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx())
224224

225225

226226
class LocationScaleIndexedEntropyModelTest(tf.test.TestCase):

tensorflow_compression/python/entropy_models/power_law_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_dtypes_are_correct_with_mixed_precision(self):
115115
self.assertEqual(penalty.dtype, tf.float16)
116116
self.assertEqual(penalty.shape, (2,))
117117
finally:
118-
tf.keras.mixed_precision.set_global_policy(None)
118+
tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx())
119119

120120

121121
if __name__ == "__main__":

tensorflow_compression/python/layers/gdn_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def test_dtypes_are_correct_with_mixed_precision(self):
207207
self.assertEqual(variable.dtype, tf.float32)
208208
self.assertEqual(y.dtype, tf.float16)
209209
finally:
210-
tf.keras.mixed_precision.set_global_policy(None)
210+
tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx())
211211

212212

213213
if __name__ == "__main__":

tensorflow_compression/python/layers/signal_conv_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def test_dtypes_are_correct_with_mixed_precision(self):
162162
self.assertEqual(variable.dtype, tf.float32)
163163
self.assertEqual(y.dtype, tf.float16)
164164
finally:
165-
tf.keras.mixed_precision.set_global_policy(None)
165+
tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx())
166166

167167

168168
class ConvolutionsTest(tf.test.TestCase):

0 commit comments

Comments
 (0)