
Commit 6ece08b

Johannes Ballé authored and copybara-github committed
Improvements to entropy models.

Contains the following changes:

- Makes the tail estimation stateless (not using `tf.Variable`s).
- Implements Keras serialization of ContinuousBatchedEntropyModel. Serializing indexed entropy models currently can't be done, since they rely on code that isn't serializable (the callables that convert `indexes` into distribution parameters).
- Adds an optimization that removes the quantization offset entirely if it happens to evaluate to zero.
- Throws an error if an attempt is made to construct range coding tables when not in eager mode.
- Uses tf.Module `name_scope`s where possible.

PiperOrigin-RevId: 305906240
Change-Id: I516c9338c13fb8d21bff2ab72b5b5c39888932a5
1 parent 3b4998c commit 6ece08b
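
For context on the serialization change: a `ContinuousBatchedEntropyModel` can now round-trip through the standard Keras `get_config`/`from_config` protocol. A minimal sketch, assuming the top-level `tfc` aliases and constructor arguments of this era of the library (they are not shown in this diff):

```python
import tensorflow_compression as tfc

# An entropy model over an illustrative prior. `compression=True` builds the
# range coding tables, which (per this commit) must happen in eager mode.
prior = tfc.NoisyDeepFactorized(batch_shape=(8,))
entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=True)

# Keras-style serialization round trip (the feature this commit adds).
config = entropy_model.get_config()
restored = tfc.ContinuousBatchedEntropyModel.from_config(config)
```

Indexed entropy models are excluded from this, since the callables that map `indexes` to distribution parameters are not serializable, as noted in the commit message.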

File tree

10 files changed: +379 additions, -114 deletions

tensorflow_compression/python/distributions/BUILD

Lines changed: 4 additions & 1 deletion
```diff
@@ -42,7 +42,10 @@ py_test(
     name = "helpers_test",
     srcs = ["helpers_test.py"],
     python_version = "PY3",
-    deps = [":helpers"],
+    deps = [
+        ":deep_factorized",
+        ":helpers",
+    ],
 )
 
 py_library(
```

tensorflow_compression/python/distributions/deep_factorized.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -178,13 +178,15 @@ def _quantization_offset(self):
     return tf.constant(0, dtype=self.dtype)
 
   def _lower_tail(self, tail_mass):
-    tail = helpers.estimate_tail(
+    tail = helpers.estimate_tails(
         self._logits_cumulative, -tf.math.log(2 / tail_mass - 1),
-        [self.batch_shape.num_elements(), 1, 1], self.dtype)
+        tf.constant([self.batch_shape.num_elements(), 1, 1], tf.int32),
+        self.dtype)
     return tf.reshape(tail, self.batch_shape_tensor())
 
   def _upper_tail(self, tail_mass):
-    tail = helpers.estimate_tail(
+    tail = helpers.estimate_tails(
         self._logits_cumulative, tf.math.log(2 / tail_mass - 1),
-        [self.batch_shape.num_elements(), 1, 1], self.dtype)
+        tf.constant([self.batch_shape.num_elements(), 1, 1], tf.int32),
+        self.dtype)
     return tf.reshape(tail, self.batch_shape_tensor())
```
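
`_quantization_offset` returning a constant zero above is what enables the commit's offset-elimination optimization: when the offset evaluates to zero, the entropy model can drop it entirely. A small sketch of how the helper surfaces this (the `batch_shape` value is arbitrary):

```python
from tensorflow_compression.python.distributions import deep_factorized
from tensorflow_compression.python.distributions import helpers

# DeepFactorized declares a zero quantization offset, so entropy models built
# on it can skip the offset addition/subtraction around quantization.
dist = deep_factorized.DeepFactorized(batch_shape=[3])
offset = helpers.quantization_offset(dist)  # == 0 for this prior
```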

tensorflow_compression/python/distributions/helpers.py

Lines changed: 61 additions & 19 deletions
```diff
@@ -19,29 +19,71 @@
 
 
 __all__ = [
+    "estimate_tails",
     "quantization_offset",
     "lower_tail",
     "upper_tail",
 ]
 
 
-def estimate_tail(func, target, shape, dtype):
-  """Estimates approximate tail quantiles."""
-  dtype = tf.as_dtype(dtype)
-  shape = tf.convert_to_tensor(shape, tf.int32)
-  target = tf.convert_to_tensor(target, dtype)
-  opt = tf.keras.optimizers.Adam(learning_rate=.1)
-  tails = tf.Variable(
-      tf.zeros(shape, dtype=dtype), trainable=False, name="tails")
-  loss = best_loss = tf.fill(shape, tf.constant(float("inf"), dtype=dtype))
-  while tf.reduce_any(loss == best_loss):
-    with tf.GradientTape(watch_accessed_variables=False) as tape:
-      tape.watch(tails)
-      loss = abs(func(tails) - target)
-    best_loss = tf.minimum(best_loss, loss)
-    gradient = tape.gradient(loss, tails)
-    opt.apply_gradients([(gradient, tails)])
-  return tails.value()
+# TODO(jonycgn): Consider wrapping in tf.function.
+def estimate_tails(func, target, shape, dtype):
+  """Estimates approximate tail quantiles.
+
+  This runs a simple Adam iteration to determine tail quantiles. The
+  objective is to find an `x` such that:
+  ```
+  func(x) == target
+  ```
+  For instance, if `func` is a CDF and the target is a quantile value, this
+  would find the approximate location of that quantile. Note that `func` is
+  assumed to be monotonic. When each tail estimate has passed the optimal value
+  of `x`, the algorithm does 10 additional iterations and then stops.
+
+  This operation is vectorized. The tensor shape of `x` is given by `shape`, and
+  `target` must have a shape that is broadcastable to the output of `func(x)`.
+
+  Arguments:
+    func: A callable that computes cumulative distribution function, survival
+      function, or similar.
+    target: The desired target value.
+    shape: The shape of the `tf.Tensor` representing `x`.
+    dtype: The `tf.dtypes.Dtype` of the computation (and the return value).
+
+  Returns:
+    A `tf.Tensor` representing the solution (`x`).
+  """
+  with tf.name_scope("estimate_tails"):
+    dtype = tf.as_dtype(dtype)
+    shape = tf.convert_to_tensor(shape, tf.int32)
+    target = tf.convert_to_tensor(target, dtype)
+
+    def loop_cond(tails, m, v, count):
+      del tails, m, v  # unused
+      return tf.reduce_min(count) < 10
+
+    def loop_body(tails, m, v, count):
+      with tf.GradientTape(watch_accessed_variables=False) as tape:
+        tape.watch(tails)
+        loss = abs(func(tails) - target)
+      grad = tape.gradient(loss, tails)
+      m = .5 * m + .5 * grad  # Adam mean estimate.
+      v = .9 * v + .1 * tf.square(grad)  # Adam variance estimate.
+      tails -= .5 * m / (tf.sqrt(v) + 1e-7)
+      # Start counting when the gradient flips sign (note that this assumes
+      # `tails` is initialized to zero).
+      count = tf.where(
+          tf.math.logical_or(count > 0, tails * grad > 0),
+          count + 1, count)
+      return tails, m, v, count
+
+    init_tails = tf.zeros(shape, dtype=dtype)
+    init_m = tf.zeros(shape, dtype=dtype)
+    init_v = tf.ones(shape, dtype=dtype)
+    init_count = tf.zeros(shape, dtype=tf.int32)
+    return tf.while_loop(
+        loop_cond, loop_body, (init_tails, init_m, init_v, init_count),
+        back_prop=False)[0]
 
 
 def quantization_offset(distribution):
@@ -113,7 +155,7 @@ def lower_tail(distribution, tail_mass):
     tail = distribution.quantile(tail_mass / 2)
   except NotImplementedError:
     try:
-      tail = estimate_tail(
+      tail = estimate_tails(
           distribution.log_cdf, tf.math.log(tail_mass / 2),
           distribution.batch_shape_tensor(), distribution.dtype)
     except NotImplementedError:
@@ -149,7 +191,7 @@ def upper_tail(distribution, tail_mass):
     tail = distribution.quantile(1 - tail_mass / 2)
   except NotImplementedError:
     try:
-      tail = estimate_tail(
+      tail = estimate_tails(
           distribution.log_survival_function, tf.math.log(tail_mass / 2),
           distribution.batch_shape_tensor(), distribution.dtype)
     except NotImplementedError:
```
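
As a usage sketch of the new stateless estimator, the fallback path that `lower_tail` takes can also be invoked directly; the `Normal` distribution and `tail_mass` value below are illustrative, not from this diff:

```python
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from tensorflow_compression.python.distributions import helpers

# Solve log_cdf(x) == log(tail_mass / 2) per batch element, i.e. locate the
# left tail of each distribution via the vectorized Adam iteration above.
dist = tfp.distributions.Normal(loc=[0., 0.], scale=[1., 2.])
tail_mass = 2. ** -8
lower = helpers.estimate_tails(
    dist.log_cdf, tf.math.log(tail_mass / 2.),
    dist.batch_shape_tensor(), dist.dtype)
# The estimates should land near the exact quantiles,
# dist.quantile(tail_mass / 2.), without creating any tf.Variables.
```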

tensorflow_compression/python/distributions/helpers_test.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp
 
+from tensorflow_compression.python.distributions import deep_factorized
 from tensorflow_compression.python.distributions import helpers
 
 
@@ -71,6 +72,11 @@ def test_normal_tails_are_in_order(self):
     self.assertGreater(
         helpers.upper_tail(dist, 2**-8), helpers.lower_tail(dist, 2**-8))
 
+  def test_deep_factorized_tails_are_in_order(self):
+    dist = deep_factorized.DeepFactorized(batch_shape=[10])
+    self.assertAllGreater(
+        helpers.upper_tail(dist, 2**-8) - helpers.lower_tail(dist, 2**-8), 0)
+
 
 if __name__ == "__main__":
   tf.test.main()
```
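
The new test mirrors the existing Normal-distribution check but exercises `estimate_tails` end to end through `DeepFactorized`'s learned CDF. With the `:deep_factorized` dependency added to the BUILD file above, it can be run in isolation, e.g. with `bazel test //tensorflow_compression/python/distributions:helpers_test` (assuming the repository's standard Bazel workflow).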

tensorflow_compression/python/entropy_models/BUILD

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":continuous_base",
+        "//tensorflow_compression/python/distributions:helpers",
         "//tensorflow_compression/python/ops:math_ops",
         "//tensorflow_compression/python/ops:range_coding_ops",
     ],
```
