Commit e0bc5c5

Johannes Ballé authored and copybara-github committed
Removes numerical tweaks for soft_round.
It seems they are mostly not necessary, except for one in soft_round_inverse. I changed this one to clip r instead of s, since it doesn't require picking constants (r is between -.5 and .5 by definition; we can simply enforce that).

PiperOrigin-RevId: 341053012
Change-Id: I8ec84b2eb5fa8e2c8d6292cab8ac69878fc0746b
1 parent c0e0fd5 commit e0bc5c5
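As a minimal standalone sketch (not part of this commit; assumes TensorFlow 2.x with eager execution), the instability the message refers to shows up at an exact integer input with a very large alpha: `s` saturates to -1 in floating point, `tf.atanh(s)` returns -inf, and clipping `r` to its defined range [-.5, .5] restores a finite forward value without picking an epsilon for `s`:

import tensorflow as tf

alpha = 1e6
y = tf.constant(0.)                       # exact integer: worst case for atanh
m = tf.floor(y) + .5                      # m = 0.5
s = (y - m) * (tf.tanh(alpha / 2.) * 2.)  # rounds to exactly -1.0 in float32
r = tf.atanh(s) / alpha                   # -inf before clipping
r = tf.clip_by_value(r, -.5, .5)          # r is in [-.5, .5] by definition
print(float(m + r))                       # 0.0, finite and equal to the input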

File tree: 2 files changed (+43 −27 lines)

2 files changed

+43
-27
lines changed

tensorflow_compression/python/ops/soft_round_ops.py

Lines changed: 14 additions & 26 deletions
@@ -20,7 +20,7 @@
__all__ = ["soft_round", "soft_round_inverse", "soft_round_conditional_mean"]


-def soft_round(x, alpha, eps=1e-12):
+def soft_round(x, alpha, eps=1e-3):
  """Differentiable approximation to round().

  Larger alphas correspond to closer approximations of the round function.
@@ -39,28 +39,19 @@ def soft_round(x, alpha, eps=1e-12):
  Returns:
    tf.Tensor
  """
-
  if isinstance(alpha, (float, int)) and alpha < eps:
    return tf.identity(x, name="soft_round")

-  m = tf.floor(x) + 0.5
+  m = tf.floor(x) + .5
  r = x - m
-  z = tf.maximum(tf.tanh(alpha / 2.0) * 2.0, eps)
+  z = tf.tanh(alpha / 2.) * 2.
  y = m + tf.tanh(alpha * r) / z

  # For very low alphas, soft_round behaves like identity
  return tf.where(alpha < eps, x, y, name="soft_round")


-@tf.custom_gradient
-def _clip_st(s):
-  """Clip s to [-1 + 1e-7, 1 - 1e-7] with straight-through gradients."""
-  s = tf.clip_by_value(s, -1 + 1e-7, 1 - 1e-7)
-  grad = lambda x: x
-  return s, grad
-
-
-def soft_round_inverse(y, alpha, eps=1e-12):
+def soft_round_inverse(y, alpha, eps=1e-3):
  """Inverse of soft_round().

  This is described in Sec. 4.1. in the paper
@@ -77,21 +68,19 @@ def soft_round_inverse(y, alpha, eps=1e-12):
  Returns:
    tf.Tensor
  """
-
  if isinstance(alpha, (float, int)) and alpha < eps:
    return tf.identity(y, name="soft_round_inverse")

-  m = tf.floor(y) + 0.5
-  s = (y - m) * (tf.tanh(alpha / 2.0) * 2.0)
-  # We have -0.5 <= (y-m) <= 0.5 and -1 < tanh < 1, so
-  # -1 <= s <= 1. However tf.atanh is only stable for inputs
-  # in the range [-1+1e-7, 1-1e-7], so we (safely) clip s to this range.
-  # In the rare case where `1-|s| < 1e-7`, we use straight-through for the
-  # gradient.
-  s = _clip_st(s)
-  r = tf.atanh(s) / tf.maximum(alpha, eps)
+  m = tf.floor(y) + .5
+  s = (y - m) * (tf.tanh(alpha / 2.) * 2.)
+  r = tf.atanh(s) / alpha
+  # `r` must be between -.5 and .5 by definition. In case atanh becomes +-inf
+  # due to numerical instability, this prevents the forward pass from yielding
+  # infinite values. Note that it doesn't prevent the backward pass from
+  # returning non-finite values.
+  r = tf.clip_by_value(r, -.5, .5)

-  # For very low alphas, soft_round behaves like identity
+  # For very low alphas, soft_round behaves like identity.
  return tf.where(alpha < eps, y, m + r, name="soft_round_inverse")


@@ -107,12 +96,11 @@ def soft_round_conditional_mean(inputs, alpha):
  > Eirikur Agustsson & Lucas Theis<br />
  > https://arxiv.org/abs/2006.09952

-
  Args:
    inputs: The input tensor.
    alpha: The softround alpha.

  Returns:
    The conditional mean, of same shape as `inputs`.
  """
-  return soft_round_inverse(inputs - 0.5, alpha) + 0.5
+  return soft_round_inverse(inputs - .5, alpha) + .5
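A quick usage sketch (not part of the commit; assumes tensorflow and tensorflow_compression are installed) showing that with the updated ops, soft_round_inverse approximately undoes soft_round and both stay finite across integers and half-integers for a moderate alpha:

import tensorflow as tf
from tensorflow_compression.python.ops import soft_round_ops

x = tf.linspace(-2., 2., 21)  # includes exact integers and half-integers
y = soft_round_ops.soft_round(x, alpha=5.)
x_hat = soft_round_ops.soft_round_inverse(y, alpha=5.)
print(bool(tf.reduce_all(tf.math.is_finite(y))))   # True
print(float(tf.reduce_max(tf.abs(x - x_hat))))     # small round-trip error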

tensorflow_compression/python/ops/soft_round_ops_test.py

Lines changed: 29 additions & 1 deletion
@@ -14,12 +14,13 @@
# ==============================================================================
"""Tests for soft round."""

+from absl.testing import parameterized
import tensorflow as tf

from tensorflow_compression.python.ops import soft_round_ops


-class SoftRoundTest(tf.test.TestCase):
+class SoftRoundTest(tf.test.TestCase, parameterized.TestCase):

  def test_soft_round_small_alpha_is_identity(self):
    x = tf.linspace(-2., 2., 50)
@@ -58,5 +59,32 @@ def test_conditional_mean_large_alpha_is_round(self):
    y = soft_round_ops.soft_round_conditional_mean(x, alpha=5000.0)
    self.assertAllClose(tf.math.round(x), y, atol=0.001)

+  @parameterized.parameters(0., 1e-6, 1e-2, 5., 1e6)
+  def test_soft_round_values_and_gradients_are_finite(self, alpha):
+    x = tf.linspace(0., 1., 11)  # covers exact integers and half-integers
+    with tf.GradientTape() as tape:
+      tape.watch(x)
+      y = soft_round_ops.soft_round(x, alpha=alpha)
+    dy = tape.gradient(y, x)
+    self.assertAllEqual(tf.math.is_finite(y), tf.ones(x.shape, dtype=bool))
+    self.assertAllEqual(tf.math.is_finite(dy), tf.ones(x.shape, dtype=bool))
+
+  @parameterized.parameters(0., 1e-6, 1e-2, 5., 1e6)
+  def test_soft_round_inverse_values_and_gradients_are_finite(self, alpha):
+    x = tf.linspace(-.5, .5, 11)  # covers exact integers and half-integers
+    with tf.GradientTape() as tape:
+      tape.watch(x)
+      y = soft_round_ops.soft_round_inverse(x, alpha=alpha)
+    dy = tape.gradient(y, x)
+    self.assertAllEqual(tf.math.is_finite(y), tf.ones(x.shape, dtype=bool))
+    if alpha > 15:
+      # We allow non-finite values for large alphas, since the function simply
+      # is extremely steep there.
+      expected_finite = tf.one_hot(5, 11, False, True)
+    else:
+      expected_finite = tf.ones(x.shape, dtype=bool)
+    self.assertAllEqual(tf.math.is_finite(dy), expected_finite)
+
+
if __name__ == "__main__":
  tf.test.main()
