Skip to content

Commit d42203e

Browse files
Johannes Ballé and copybara-github
authored and committed
Enables instantiating entropy model in graph mode with compression=True.
PiperOrigin-RevId: 314736009 Change-Id: I02322aa86358d428eb7155ad7f3818c53e1a1f17
1 parent 8692f3c commit d42203e

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,8 @@ def __init__(self, prior, coding_rank, compression=False,
5454
unit. Each coding unit is compressed to its own bit string, and the
5555
`bits()` method sums over each coding unit.
5656
compression: Boolean. If set to `True`, the range coding tables used by
57-
`compress()` and `decompress()` will be built on instantiation. This
58-
assumes eager mode (throws an error if in graph mode or inside a
59-
`tf.function` call). If set to `False`, these two methods will not be
60-
accessible.
57+
`compress()` and `decompress()` will be built on instantiation. If set
58+
to `False`, these two methods will not be accessible.
6159
likelihood_bound: Float. Lower bound for likelihood values, to prevent
6260
training instabilities.
6361
tail_mass: Float. Approximate probability mass which is range encoded with
@@ -81,8 +79,6 @@ def __init__(self, prior, coding_rank, compression=False,
8179
self._tail_mass = float(tail_mass)
8280
self._range_coder_precision = int(range_coder_precision)
8381
if self.compression:
84-
if not tf.executing_eagerly():
85-
raise RuntimeError("`compression=True` requires eager execution.")
8682
self._build_tables(prior)
8783

8884
@property
@@ -207,7 +203,7 @@ def _build_tables(self, prior):
207203
# Sample the densities in the computed ranges, possibly computing more
208204
# samples than necessary at the upper end.
209205
max_length = tf.math.reduce_max(pmf_length)
210-
if max_length > 2048:
206+
if tf.executing_eagerly() and max_length > 2048:
211207
logging.warning(
212208
"Very wide PMF with %d elements may lead to out of memory issues. "
213209
"Consider priors with smaller dispersion or increasing `tail_mass` "

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,8 @@ def __init__(self, prior, coding_rank, compression=False,
9292
unit. Each coding unit is compressed to its own bit string, and the
9393
`bits()` method sums over each coding unit.
9494
compression: Boolean. If set to `True`, the range coding tables used by
95-
`compress()` and `decompress()` will be built on instantiation. This
96-
assumes eager mode (throws an error if in graph mode or inside a
97-
`tf.function` call). If set to `False`, these two methods will not be
98-
accessible.
95+
`compress()` and `decompress()` will be built on instantiation. If set
96+
to `False`, these two methods will not be accessible.
9997
likelihood_bound: Float. Lower bound for likelihood values, to prevent
10098
training instabilities.
10199
tail_mass: Float. Approximate probability mass which is range encoded with
@@ -119,7 +117,10 @@ def __init__(self, prior, coding_rank, compression=False,
119117
# Optimization: if the quantization offset is zero, we don't need to
120118
# subtract/add it when quantizing, and we don't need to serialize its
121119
# value. Note that this code will only work in eager mode.
122-
if tf.reduce_all(tf.equal(quantization_offset, 0.)):
120+
# TODO(jonycgn): Reconsider if this optimization is worth keeping once
121+
# the implementation is stable.
122+
if tf.executing_eagerly() and tf.reduce_all(
123+
tf.equal(quantization_offset, 0.)):
123124
quantization_offset = None
124125
else:
125126
quantization_offset = tf.broadcast_to(
@@ -260,6 +261,8 @@ def decompress(self, strings, broadcast_shape):
260261
A `tf.Tensor` of shape `strings.shape + broadcast_shape +
261262
self.prior_shape`.
262263
"""
264+
strings = tf.convert_to_tensor(strings, dtype=tf.string)
265+
broadcast_shape = tf.convert_to_tensor(broadcast_shape, dtype=tf.int32)
263266
batch_shape = tf.shape(strings)
264267
symbols_shape = tf.concat(
265268
[batch_shape, broadcast_shape, self.prior_shape], 0)

tensorflow_compression/python/entropy_models/continuous_batched_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,5 +154,28 @@ def test_compression_works_after_serialization_no_offset(self):
154154
self.assertAllEqual(em.compress(x), x_compressed)
155155
self.assertAllEqual(em.decompress(x_compressed, [100]), x_quantized)
156156

157+
def test_compression_works_in_tf_function(self):
158+
noisy = uniform_noise.NoisyNormal(loc=0, scale=5.)
159+
sample = noisy.base.sample([100])
160+
161+
# Since tf.function traces each function twice, and only allows variable
162+
# creation in the first call, we need to have a stateful object in which we
163+
# create the entropy model only the first time the function is called, and
164+
# store it for the second time.
165+
166+
class Compressor(object):
167+
168+
def compress(self, values):
169+
if not hasattr(self, "em"):
170+
self.em = ContinuousBatchedEntropyModel(noisy, 1, compression=True)
171+
compressed = self.em.compress(values)
172+
decompressed = self.em.decompress(compressed, [])
173+
return decompressed
174+
175+
values_eager = Compressor().compress(sample)
176+
values_function = tf.function(Compressor().compress)(sample)
177+
self.assertAllEqual(values_eager, values_function)
178+
179+
157180
if __name__ == "__main__":
158181
tf.test.main()

0 commit comments

Comments (0)