Skip to content

Commit 7a6899c

Browse files
Johannes Ballécopybara-github
authored andcommitted
Fixes platform dependency in soft_round_ops unit test.
The gradient value of the inverse soft round function at 0 is allowed to evaluate as infinite, but can sometimes be finite depending on the platform. This change makes that point a don't-care value. PiperOrigin-RevId: 361500720 Change-Id: I0cf42414d41f67b66b5fd25f5e723a2f23324c64
1 parent bbe51f0 commit 7a6899c

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tensorflow_compression/python/ops/soft_round_ops_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@ def test_soft_round_inverse_values_and_gradients_are_finite(self, alpha):
7777
y = soft_round_ops.soft_round_inverse(x, alpha=alpha)
7878
dy = tape.gradient(y, x)
7979
self.assertAllEqual(tf.math.is_finite(y), tf.ones(x.shape, dtype=bool))
80+
is_finite = tf.math.is_finite(dy)
81+
expected_finite = tf.ones(dy.shape, dtype=bool)
8082
if alpha > 15:
81-
# We allow non-finite values for large alphas, since the function simply
82-
# is extremely steep there.
83-
expected_finite = tf.one_hot(5, 11, False, True)
84-
else:
85-
expected_finite = tf.ones(x.shape, dtype=bool)
86-
self.assertAllEqual(tf.math.is_finite(dy), expected_finite)
83+
# We allow non-finite values at 0 for large alphas, since the function
84+
# simply is extremely steep there.
85+
expected_finite = tf.tensor_scatter_nd_update(
86+
expected_finite, [[5]], [is_finite[5]])
87+
self.assertAllEqual(is_finite, expected_finite)
8788

8889

8990
if __name__ == "__main__":

0 commit comments

Comments
 (0)