Skip to content

Commit cf96d1b

Browse files
relationalcopybara-github
authored andcommitted
Remove special casing in soft-round for non-tensor alpha and prevent tf.where NaNs.
PiperOrigin-RevId: 342666795 Change-Id: I2352fba14a5fd5454158793fe5eb55dfc0e0f793
1 parent e5872ee commit cf96d1b

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

tensorflow_compression/python/ops/soft_round_ops.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ def soft_round(x, alpha, eps=1e-3):
3939
Returns:
4040
tf.Tensor
4141
"""
42-
if isinstance(alpha, (float, int)) and alpha < eps:
43-
return tf.identity(x, name="soft_round")
42+
# This guards the gradient of tf.where below against NaNs, while maintaining
43+
# correctness, as for alpha < eps the result is ignored.
44+
alpha_bounded = tf.maximum(alpha, eps)
4445

4546
m = tf.floor(x) + .5
4647
r = x - m
47-
z = tf.tanh(alpha / 2.) * 2.
48-
y = m + tf.tanh(alpha * r) / z
48+
z = tf.tanh(alpha_bounded / 2.) * 2.
49+
y = m + tf.tanh(alpha_bounded * r) / z
4950

5051
# For very low alphas, soft_round behaves like identity
5152
return tf.where(alpha < eps, x, y, name="soft_round")
@@ -68,12 +69,12 @@ def soft_round_inverse(y, alpha, eps=1e-3):
6869
Returns:
6970
tf.Tensor
7071
"""
71-
if isinstance(alpha, (float, int)) and alpha < eps:
72-
return tf.identity(y, name="soft_round_inverse")
73-
72+
# This guards the gradient of tf.where below against NaNs, while maintaining
73+
# correctness, as for alpha < eps the result is ignored.
74+
alpha_bounded = tf.maximum(alpha, eps)
7475
m = tf.floor(y) + .5
75-
s = (y - m) * (tf.tanh(alpha / 2.) * 2.)
76-
r = tf.atanh(s) / alpha
76+
s = (y - m) * (tf.tanh(alpha_bounded / 2.) * 2.)
77+
r = tf.atanh(s) / alpha_bounded
7778
# `r` must be between -.5 and .5 by definition. In case atanh becomes +-inf
7879
# due to numerical instability, this prevents the forward pass from yielding
7980
# infinite values. Note that it doesn't prevent the backward pass from

0 commit comments

Comments
 (0)