Skip to content

Commit 4e64806

Browse files
Johannes Ballécopybara-github
authored andcommitted
Addresses precision issues in unit tests on A100 platform.
This should fix #155 PiperOrigin-RevId: 487422439 Change-Id: I612ad5e2fdaed1e50bc36fcb17723bfa6ea4d0a1
1 parent 892e137 commit 4e64806

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

tensorflow_compression/python/distributions/deep_factorized_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323

2424
class DeepFactorizedTest(tf.test.TestCase, parameterized.TestCase):
2525

26+
def setUp(self):
27+
super().setUp()
28+
# Disable TensorFloat-32 format on A100 platform, as precision is too low
29+
# for current test assertions.
30+
tf.config.experimental.enable_tensor_float_32_execution(False)
31+
2632
def test_can_instantiate_scalar(self):
2733
df = deep_factorized.DeepFactorized()
2834
self.assertEqual(df.batch_shape, ())
@@ -66,6 +72,12 @@ def test_broadcasts_correctly(self, method):
6672

6773
class NoisyDeepFactorizedTest(tf.test.TestCase):
6874

75+
def setUp(self):
76+
super().setUp()
77+
# Disable TensorFloat-32 format on A100 platform, as precision is too low
78+
# for current test assertions.
79+
tf.config.experimental.enable_tensor_float_32_execution(False)
80+
6981
def test_can_instantiate_and_run_scalar(self):
7082
df = deep_factorized.NoisyDeepFactorized(num_filters=(2, 3, 4))
7183
self.assertEqual(df.batch_shape, ())

tensorflow_compression/python/layers/gdn_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323

2424
class GDNTest(tf.test.TestCase, parameterized.TestCase):
2525

26+
def setUp(self):
27+
super().setUp()
28+
# Disable TensorFloat-32 format on A100 platform, as precision is too low
29+
# for current test assertions.
30+
tf.config.experimental.enable_tensor_float_32_execution(False)
31+
2632
def test_invalid_data_format_raises_error(self):
2733
with self.assertRaises(ValueError):
2834
gdn.GDN(data_format="NHWC")

0 commit comments

Comments
 (0)