
Commit c0e0fd5

Johannes Ballé authored and copybara-github committed
Tunes estimate_tails.
Uses a new stopping criterion that doesn't assume anything about initialization, and hand-tweaks parameters to deal well with distributions that occur in practice.

PiperOrigin-RevId: 340860445
Change-Id: I515b30cccb1e79703743b7d9f5aac1b1eb0fbacd
1 parent 9054647 · commit c0e0fd5
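As a toy illustration of the criterion change (hypothetical code, not part of the commit): the helper below runs plain sign-gradient descent on the scalar loss |x - opt| and reports the iteration at which each criterion would start the stop counter. The name first_count_iteration, the learning rate, and the omission of the Adam normalization are all simplifications for illustration.

def first_count_iteration(x, opt, criterion, lr=.1, steps=200):
  """Iteration at which the stop counter starts ticking, or None."""
  prev_m = 0.
  for i in range(steps):
    grad = 1. if x > opt else -1.  # gradient of the loss |x - opt|
    if criterion == "old":
      triggered = x * grad > 0  # only meaningful if x starts at 0
    else:
      triggered = prev_m * grad < 0  # pure gradient sign flip
    if triggered:
      return i
    prev_m = (prev_m + grad) / 2  # running mean, as in the patch
    x -= lr * grad  # plain gradient step
  return None

# Optimum at -3. Starting from x = 0, both criteria fire at the
# crossing (around iteration 30):
print(first_count_iteration(0., -3., "old"),
      first_count_iteration(0., -3., "new"))
# Starting from x = -5, the old criterion fires at iteration 0, long
# before the estimate crosses -3; the new one still fires at the crossing:
print(first_count_iteration(-5., -3., "old"),
      first_count_iteration(-5., -3., "new"))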


tensorflow_compression/python/distributions/helpers.py

Lines changed: 10 additions & 9 deletions
@@ -37,7 +37,7 @@ def estimate_tails(func, target, shape, dtype):
   For instance, if `func` is a CDF and the target is a quantile value, this
   would find the approximate location of that quantile. Note that `func` is
   assumed to be monotonic. When each tail estimate has passed the optimal value
-  of `x`, the algorithm does 10 additional iterations and then stops.
+  of `x`, the algorithm does 100 additional iterations and then stops.

   This operation is vectorized. The tensor shape of `x` is given by `shape`, and
   `target` must have a shape that is broadcastable to the output of `func(x)`.
@@ -59,20 +59,21 @@ def estimate_tails(func, target, shape, dtype):

     def loop_cond(tails, m, v, count):
       del tails, m, v  # unused
-      return tf.reduce_min(count) < 10
+      return tf.reduce_min(count) < 100

-    def loop_body(tails, m, v, count):
+    def loop_body(tails, prev_m, prev_v, count):
       with tf.GradientTape(watch_accessed_variables=False) as tape:
         tape.watch(tails)
         loss = abs(func(tails) - target)
       grad = tape.gradient(loss, tails)
-      m = .5 * m + .5 * grad  # Adam mean estimate.
-      v = .9 * v + .1 * tf.square(grad)  # Adam variance estimate.
-      tails -= .5 * m / (tf.sqrt(v) + 1e-7)
-      # Start counting when the gradient flips sign (note that this assumes
-      # `tails` is initialized to zero).
+      m = (prev_m + grad) / 2  # Adam mean estimate.
+      v = (prev_v + tf.square(grad)) / 2  # Adam variance estimate.
+      tails -= .1 * m / (tf.sqrt(v) + 1e-20)
+      # Start counting when the gradient flips sign. Since the function is
+      # monotonic, m must have the same sign in all initial iterations, until
+      # the optimal point is crossed. At that point the gradient flips sign.
       count = tf.where(
-          tf.math.logical_or(count > 0, tails * grad > 0),
+          tf.math.logical_or(count > 0, prev_m * grad < 0),
           count + 1, count)
       return tails, m, v, count

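For context, here is a self-contained sketch of the routine as it reads after this commit. The loop_cond and loop_body follow the diff above; the zero initialization and the tf.while_loop wiring outside the hunks are reconstructed assumptions, not verbatim file contents.

import tensorflow as tf


def estimate_tails(func, target, shape, dtype):
  """Finds x such that func(x) == target, assuming func is monotonic."""
  dtype = tf.as_dtype(dtype)

  def loop_cond(tails, m, v, count):
    del tails, m, v  # unused
    # Run until every element has spent 100 iterations past its optimum.
    return tf.reduce_min(count) < 100

  def loop_body(tails, prev_m, prev_v, count):
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(tails)
      loss = abs(func(tails) - target)
    grad = tape.gradient(loss, tails)
    m = (prev_m + grad) / 2  # Adam mean estimate.
    v = (prev_v + tf.square(grad)) / 2  # Adam variance estimate.
    tails -= .1 * m / (tf.sqrt(v) + 1e-20)
    # The gradient keeps one sign until the optimum is crossed; a sign
    # flip against the running mean starts (or continues) the counter.
    count = tf.where(
        tf.math.logical_or(count > 0, prev_m * grad < 0),
        count + 1, count)
    return tails, m, v, count

  # Assumed initialization: zero estimates and counts.
  init = (tf.zeros(shape, dtype), tf.zeros(shape, dtype),
          tf.zeros(shape, dtype), tf.zeros(shape, tf.int32))
  return tf.while_loop(loop_cond, loop_body, init)[0]


# Example: the logistic CDF is the sigmoid, so the 1e-3 quantile should
# come out near log(1e-3 / (1 - 1e-3)) ~= -6.9.
print(estimate_tails(tf.math.sigmoid, 1e-3, [1], tf.float32))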