Skip to content

Commit de969cd

Browse files
Johannes Ballécopybara-github
authored andcommitted
Re-implements entropy models using new range coder.
- Makes the ContinuousBatchedEntropyModel class more memory efficient, as it doesn't need to create an index table in memory any more. - Gets rid of the TF control flow in `compress()` and `decompress()`, since the new range coder can handle multiple streams. - Makes the CDFs more compact. Rather than a zero-padded 2D Tensor with `cdf_length` specifying the length of each CDF, now uses the packed format implemented in the range coder. So `cdf_length` is not needed any more, and both `cdf` and `cdf_offset` are now always 1D Tensors. - Uses the argument `cdf_shapes` to specify the shape of both `cdf` and `cdf_offset` for creating placeholders when using the `SavedModel` protocol. PiperOrigin-RevId: 421637496 Change-Id: Id61de16644694de4ff8ce505d57c420529113124
1 parent 61e7977 commit de969cd

File tree

7 files changed

+169
-343
lines changed

7 files changed

+169
-343
lines changed

tensorflow_compression/python/distributions/helpers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ def lower_tail(distribution, tail_mass):
139139
number of symbols. This method returns a cut-off location for the lower
140140
tail, such that approximately `tail_mass` probability mass is contained in
141141
the tails (together). The tails are then handled by using the 'overflow'
142-
functionality of the range coder implementation (using a Golomb-like
143-
universal code).
142+
functionality of the range coder implementation (using an Elias gamma code).
144143
145144
Args:
146145
distribution: A `tfp.distributions.Distribution` object.
@@ -176,8 +175,7 @@ def upper_tail(distribution, tail_mass):
176175
number of symbols. This method returns a cut-off location for the upper
177176
tail, such that approximately `tail_mass` probability mass is contained in
178177
the tails (together). The tails are then handled by using the 'overflow'
179-
functionality of the range coder implementation (using a Golomb-like
180-
universal code).
178+
functionality of the range coder implementation (using an Elias gamma code).
181179
182180
Args:
183181
distribution: A `tfp.distributions.Distribution` object.

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 61 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def __init__(self,
4343
stateless=False,
4444
expected_grads=False,
4545
tail_mass=2**-8,
46-
range_coder_precision=12,
4746
dtype=None,
4847
laplace_tail_mass=0):
4948
"""Initializes the instance.
@@ -65,12 +64,11 @@ def __init__(self,
6564
`stateless=True` is implied and the provided value is ignored.
6665
expected_grads: If True, will use analytical expected gradients during
6766
backpropagation w.r.t. additive uniform noise.
68-
tail_mass: Float. Approximate probability mass which is range encoded with
69-
less precision, by using a Golomb-like code.
70-
range_coder_precision: Integer. Precision passed to the range coding op.
67+
tail_mass: Float. Approximate probability mass which is encoded using an
68+
Elias gamma code embedded into the range coder.
7169
dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of
7270
prior, decompressed values).
73-
laplace_tail_mass: Float. If positive, will augment the prior with a
71+
laplace_tail_mass: Float. If non-zero, will augment the prior with a
7472
Laplace mixture for training stability. (experimental)
7573
"""
7674
super().__init__()
@@ -80,11 +78,18 @@ def __init__(self,
8078
self._stateless = bool(stateless)
8179
self._expected_grads = bool(expected_grads)
8280
self._tail_mass = float(tail_mass)
83-
self._range_coder_precision = int(range_coder_precision)
8481
self._dtype = tf.as_dtype(dtype)
8582
self._laplace_tail_mass = float(laplace_tail_mass)
83+
84+
if self.coding_rank < 0:
85+
raise ValueError("`coding_rank` must be at least 0.")
86+
if not 0 < self.tail_mass < 1:
87+
raise ValueError("`tail_mass` must be between 0 and 1.")
88+
if not 0 <= self.laplace_tail_mass < 1:
89+
raise ValueError("`laplace_tail_mass` must be between 0 and 1.")
90+
8691
with self.name_scope:
87-
self._laplace_prior = (tfp.distributions.Laplace(loc=0.0, scale=1.0)
92+
self._laplace_prior = (tfp.distributions.Laplace(loc=0., scale=1.)
8893
if laplace_tail_mass else None)
8994

9095
def _check_compression(self):
@@ -117,11 +122,6 @@ def cdf_offset(self):
117122
self._check_compression()
118123
return tf.convert_to_tensor(self._cdf_offset)
119124

120-
@property
121-
def cdf_length(self):
122-
self._check_compression()
123-
return tf.convert_to_tensor(self._cdf_length)
124-
125125
@property
126126
def dtype(self):
127127
"""Data type of this entropy model."""
@@ -159,16 +159,16 @@ def tail_mass(self):
159159

160160
@property
161161
def range_coder_precision(self):
162-
"""Precision passed to range coding op."""
163-
return self._range_coder_precision
162+
"""Precision used in range coding op."""
163+
return -self.cdf[0]
164164

165-
def _init_compression(self, cdf, cdf_offset, cdf_length, cdf_shape):
165+
def _init_compression(self, cdf, cdf_offset, cdf_shapes):
166166
"""Sets up this entropy model for using the range coder.
167167
168-
This is done by storing `cdf`, `cdf_offset`, and `cdf_length` in
169-
`tf.Variable`s (`stateless=False`) or `tf.Tensor`s (`stateless=True`) as
170-
attributes of this object, or creating the variables as placeholders if
171-
`cdf_shape` is provided.
168+
This is done by storing `cdf` and `cdf_offset` in `tf.Variable`s
169+
(`stateless=False`) or `tf.Tensor`s (`stateless=True`) as attributes of this
170+
object, or creating the variables as placeholders if `cdf_shapes` is
171+
provided.
172172
173173
The reason for pre-computing the tables is that they must not be
174174
re-generated independently on the sending and receiving side, since small
@@ -184,41 +184,33 @@ def _init_compression(self, cdf, cdf_offset, cdf_length, cdf_shape):
184184
Args:
185185
cdf: CDF table for range coder.
186186
cdf_offset: CDF offset table for range coder.
187-
cdf_length: CDF length table for range coder.
188-
cdf_shape: Iterable of 2 integers, the shape of `cdf`. Mutually exclusive
189-
with the other three arguments. If provided, creates placeholder values
190-
for them.
187+
cdf_shapes: Iterable of integers, the shapes of `cdf` and `cdf_offset`.
188+
Mutually exclusive with the other two arguments. If provided, creates
189+
placeholder values for them.
191190
"""
192-
if not ((cdf is None) == (cdf_offset is None) == (cdf_length is None) ==
193-
(cdf_shape is not None)):
191+
if not (cdf is None) == (cdf_offset is None) == (cdf_shapes is not None):
194192
raise ValueError(
195-
"Either all of `cdf`, `cdf_offset`, and `cdf_length`; or `cdf_shape` "
196-
"must be provided.")
197-
if cdf_shape is not None:
193+
"Either both `cdf` and `cdf_offset`, or `cdf_shapes` must be "
194+
"provided.")
195+
if cdf_shapes is not None:
198196
if self.stateless:
199-
raise ValueError("With `stateless=True`, can't provide `cdf_shape`.")
200-
cdf_shape = tuple(map(int, cdf_shape))
201-
if len(cdf_shape) != 2:
202-
raise ValueError("`cdf_shape` must consist of 2 integers.")
203-
zeros = tf.zeros(cdf_shape, dtype=tf.int32)
204-
cdf = zeros
205-
cdf_offset = zeros[:, 0]
206-
cdf_length = zeros[:, 0]
197+
raise ValueError("With `stateless=True`, can't provide `cdf_shapes`.")
198+
cdf_shapes = tuple(map(int, cdf_shapes))
199+
if len(cdf_shapes) != 2:
200+
raise ValueError("`cdf_shapes` must have two elements.")
201+
cdf = tf.zeros(cdf_shapes[:1], dtype=tf.int32)
202+
cdf_offset = tf.zeros(cdf_shapes[1:], dtype=tf.int32)
207203
if self.stateless:
208204
self._cdf = tf.convert_to_tensor(cdf, dtype=tf.int32, name="cdf")
209205
self._cdf_offset = tf.convert_to_tensor(
210206
cdf_offset, dtype=tf.int32, name="cdf_offset")
211-
self._cdf_length = tf.convert_to_tensor(
212-
cdf_length, dtype=tf.int32, name="cdf_length")
213207
else:
214208
self._cdf = tf.Variable(
215209
cdf, dtype=tf.int32, trainable=False, name="cdf")
216210
self._cdf_offset = tf.Variable(
217211
cdf_offset, dtype=tf.int32, trainable=False, name="cdf_offset")
218-
self._cdf_length = tf.Variable(
219-
cdf_length, dtype=tf.int32, trainable=False, name="cdf_length")
220212

221-
def _build_tables(self, prior, offset=None, context_shape=None):
213+
def _build_tables(self, prior, precision, offset=None):
222214
"""Computes integer-valued probability tables used by the range coder.
223215
224216
These tables must not be re-generated independently on the sending and
@@ -233,18 +225,16 @@ def _build_tables(self, prior, offset=None, context_shape=None):
233225
234226
Args:
235227
prior: The `tfp.distributions.Distribution` object (see initializer).
236-
offset: Quantization offsets to use for sampling prior probabilities.
237-
Defaults to 0.
238-
context_shape: Shape of innermost dimensions to evaluate the prior on.
239-
Defaults to and must include `prior.batch_shape`.
228+
precision: Integer. Precision for range coder.
229+
offset: None or float tensor between -.5 and +.5. Sub-integer quantization
230+
offsets to use for sampling prior probabilities. Defaults to 0.
240231
241232
Returns:
242233
CDF table, CDF offsets, CDF lengths.
243234
"""
235+
precision = int(precision)
244236
if offset is None:
245237
offset = 0.
246-
if context_shape is None:
247-
context_shape = tf.TensorShape(prior.batch_shape)
248238
# Subclasses should have already caught this, but better be safe.
249239
assert not prior.event_shape.rank
250240

@@ -269,38 +259,38 @@ def _build_tables(self, prior, offset=None, context_shape=None):
269259
"Consider priors with smaller variance, or increasing `tail_mass` "
270260
"parameter.", int(max_length))
271261
samples = tf.range(tf.cast(max_length, self.dtype), dtype=self.dtype)
272-
samples = tf.reshape(samples, [-1] + context_shape.rank * [1])
262+
samples = tf.reshape(samples, [-1] + pmf_length.shape.rank * [1])
273263
samples += pmf_start
274264
pmf = prior.prob(samples)
265+
pmf_shape = tf.shape(pmf)[1:]
266+
num_pmfs = tf.reduce_prod(pmf_shape)
275267

276268
# Collapse batch dimensions of distribution.
277-
pmf = tf.reshape(pmf, [max_length, -1])
269+
pmf = tf.reshape(pmf, [max_length, num_pmfs])
278270
pmf = tf.transpose(pmf)
279271

280-
context_shape = tf.constant(context_shape.as_list(), dtype=tf.int32)
281-
pmf_length = tf.broadcast_to(pmf_length, context_shape)
282-
pmf_length = tf.reshape(pmf_length, [-1])
283-
cdf_length = pmf_length + 2
284-
cdf_offset = tf.broadcast_to(minima, context_shape)
285-
cdf_offset = tf.reshape(cdf_offset, [-1])
272+
pmf_length = tf.broadcast_to(pmf_length, pmf_shape)
273+
pmf_length = tf.reshape(pmf_length, [num_pmfs])
274+
cdf_offset = tf.broadcast_to(minima, pmf_shape)
275+
cdf_offset = tf.reshape(cdf_offset, [num_pmfs])
276+
precision_tensor = tf.constant([-precision], dtype=tf.int32)
286277

287278
# Prevent tensors from bouncing back and forth between host and GPU.
288279
with tf.device("/cpu:0"):
289-
def loop_body(args):
290-
prob, length = args
291-
prob = prob[:length]
292-
overflow = tf.math.maximum(1 - tf.reduce_sum(prob, keepdims=True), 0.)
293-
prob = tf.concat([prob, overflow], axis=0)
294-
cdf = gen_ops.pmf_to_quantized_cdf(
295-
tf.cast(prob, tf.float32), precision=self.range_coder_precision)
296-
return tf.pad(
297-
cdf, [[0, max_length - length]], mode="CONSTANT", constant_values=0)
298-
299-
# TODO(jonycgn,ssjhv): Consider switching to Python control flow.
300-
cdf = tf.map_fn(
301-
loop_body, (pmf, pmf_length), dtype=tf.int32, name="pmf_to_cdf")
302-
303-
return cdf, cdf_offset, cdf_length
280+
def loop_body(i, cdf):
281+
p = pmf[i, :pmf_length[i]]
282+
overflow = tf.math.maximum(1. - tf.reduce_sum(p, keepdims=True), 0.)
283+
p = tf.cast(tf.concat([p, overflow], 0), tf.float32)
284+
c = gen_ops.pmf_to_quantized_cdf(p, precision=precision)
285+
return i + 1, tf.concat([cdf, precision_tensor, c], 0)
286+
i_0 = tf.constant(0, tf.int32)
287+
cdf_0 = tf.constant([], tf.int32)
288+
_, cdf = tf.while_loop(
289+
lambda i, _: i < num_pmfs, loop_body, (i_0, cdf_0),
290+
shape_invariants=(i_0.shape, tf.TensorShape([None])),
291+
name="pmf_to_cdf")
292+
293+
return cdf, cdf_offset
304294

305295
def _log_prob(self, prior, bottleneck_perturbed):
306296
"""Evaluates prior.log_prob(bottleneck + noise)."""
@@ -341,8 +331,7 @@ def get_config(self):
341331
stateless=False,
342332
expected_grads=self.expected_grads,
343333
tail_mass=self.tail_mass,
344-
range_coder_precision=self.range_coder_precision,
345-
cdf_shape=tuple(map(int, self.cdf.shape)),
334+
cdf_shapes=(self.cdf.shape[0], self.cdf_offset.shape[0]),
346335
dtype=self.dtype.name,
347336
laplace_tail_mass=self.laplace_tail_mass,
348337
)

0 commit comments

Comments
 (0)