Skip to content

Commit 40cf959

Browse files
Johannes Ballécopybara-github
authored andcommitted
Makes entropy models stateless.
The stateful implementation of the entropy models was broken. For instance, if the model was set up like this: ``` log_scale = tf.Variable() scale = tf.exp(log_scale) em = ContinuousBatchedEntropyModel(Normal(scale=scale)) ``` A subsequent training loop would not be able to backpropagate into `log_scale` in eager mode. It is less error prone to make entropy models stateless and re-instantiate them in the loop, like `tfp.distributions` objects. After training, the user would make another instance using `compression=True` and then share this model with sender and receiver. In addition, this renames the generic `distribution` attribute of the entropy model to `prior` (this doesn't necessarily indicate a Bayesian prior, but simply a distribution that is known `a priori` to both sender and receiver). PiperOrigin-RevId: 296313322 Change-Id: I464063b28dd1e3c0c7fa30324361ae604084bcce
1 parent 4713cd7 commit 40cf959

File tree

5 files changed

+208
-145
lines changed

5 files changed

+208
-145
lines changed

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -37,46 +37,79 @@ class ContinuousEntropyModelBase(tf.Module, metaclass=abc.ABCMeta):
3737
"""
3838

3939
@abc.abstractmethod
40-
def __init__(self, distribution, coding_rank,
40+
def __init__(self, prior, coding_rank, compression=False,
4141
likelihood_bound=1e-9, tail_mass=2**-8,
4242
range_coder_precision=12):
4343
"""Initializer.
4444
4545
Arguments:
46-
distribution: A `tfp.distributions.Distribution` object modeling the
47-
distribution of the input data including additive uniform noise. For
46+
prior: A `tfp.distributions.Distribution` object. A density model fitting
47+
the marginal distribution of the bottleneck data with additive uniform
48+
noise, which is shared a priori between the sender and the receiver. For
4849
best results, the distribution should be flexible enough to have a
49-
unit-width uniform distribution as a special case.
50+
unit-width uniform distribution as a special case, since this is the
51+
marginal distribution for bottleneck dimensions that are constant.
5052
coding_rank: Integer. Number of innermost dimensions considered a coding
5153
unit. Each coding unit is compressed to its own bit string, and the
5254
`bits()` method sums over each coding unit.
55+
compression: Boolean. If set to `True`, the range coding tables used by
56+
`compress()` and `decompress()` will be built on instantiation.
57+
Otherwise, some computation can be saved, but these two methods will not
58+
be accessible.
5359
likelihood_bound: Float. Lower bound for likelihood values, to prevent
5460
training instabilities.
5561
tail_mass: Float. Approximate probability mass which is range encoded with
5662
less precision, by using a Golomb-like code.
5763
range_coder_precision: Integer. Precision passed to the range coding op.
5864
"""
59-
if not distribution.is_scalar_event():
60-
raise ValueError(
61-
"`distribution` must be a (batch of) scalar distribution(s).")
65+
if not prior.is_scalar_event():
66+
raise ValueError("`prior` must be a (batch of) scalar distribution(s).")
6267
super().__init__()
63-
self._distribution = distribution
68+
self._prior = prior
6469
self._coding_rank = int(coding_rank)
70+
self._compression = bool(compression)
6571
self._likelihood_bound = float(likelihood_bound)
6672
self._tail_mass = float(tail_mass)
6773
self._range_coder_precision = int(range_coder_precision)
68-
self.update_tables()
74+
if self.compression:
75+
self._build_tables()
6976

7077
@property
71-
def distribution(self):
72-
"""Distribution modeling data + i.i.d. uniform noise."""
73-
return self._distribution
78+
def prior(self):
79+
"""Prior distribution, used for range coding."""
80+
return self._prior
81+
82+
def _check_compression(self):
83+
if not self.compression:
84+
raise RuntimeError(
85+
"To use range coding, the entropy model must be instantiated with "
86+
"`compression=True`.")
87+
88+
@property
89+
def cdf(self):
90+
self._check_compression()
91+
return self._cdf.value()
92+
93+
@property
94+
def cdf_offset(self):
95+
self._check_compression()
96+
return self._cdf_offset.value()
97+
98+
@property
99+
def cdf_length(self):
100+
self._check_compression()
101+
return self._cdf_length.value()
74102

75103
@property
76104
def coding_rank(self):
77105
"""Number of innermost dimensions considered a coding unit."""
78106
return self._coding_rank
79107

108+
@property
109+
def compression(self):
110+
"""Whether this entropy model is prepared for compression."""
111+
return self._compression
112+
80113
@property
81114
def likelihood_bound(self):
82115
"""Lower bound for likelihood values."""
@@ -94,27 +127,27 @@ def range_coder_precision(self):
94127

95128
@property
96129
def dtype(self):
97-
"""Data type of this distribution."""
98-
return self.distribution.dtype
130+
"""Data type of this entropy model."""
131+
return self.prior.dtype
99132

100133
def quantization_offset(self):
101134
"""Distribution-dependent quantization offset."""
102-
return helpers.quantization_offset(self.distribution)
135+
return helpers.quantization_offset(self.prior)
103136

104137
def lower_tail(self):
105138
"""Approximate lower tail quantile for range coding."""
106-
return helpers.lower_tail(self.distribution, self.tail_mass)
139+
return helpers.lower_tail(self.prior, self.tail_mass)
107140

108141
def upper_tail(self):
109142
"""Approximate upper tail quantile for range coding."""
110-
return helpers.upper_tail(self.distribution, self.tail_mass)
143+
return helpers.upper_tail(self.prior, self.tail_mass)
111144

112145
@tf.custom_gradient
113146
def _quantize(self, inputs, offset):
114147
return tf.round(inputs - offset) + offset, lambda x: (x, None)
115148

116-
def update_tables(self):
117-
"""Updates integer-valued probability tables used by the range coder.
149+
def _build_tables(self):
150+
"""Computes integer-valued probability tables used by the range coder.
118151
119152
These tables must not be re-generated independently on the sending and
120153
receiving side, since small numerical discrepancies between both sides can
@@ -126,9 +159,10 @@ def update_tables(self):
126159
> J. Ballé, N. Johnston, D. Minnen<br />
127160
> https://openreview.net/forum?id=S1zz2i0cY7
128161
129-
The tables are stored in `tf.Tensor`s as attributes of this object. The
130-
recommended way is to train the model, then call this method, and then
131-
distribute the model to a sender and a receiver.
162+
The tables are stored in `tf.Variable`s as attributes of this object. The
163+
recommended way is to train the model, instantiate an entropy model with
164+
`compression=True`, and then distribute the model to a sender and a
165+
receiver.
132166
"""
133167
offset = self.quantization_offset()
134168
lower_tail = self.lower_tail()
@@ -153,19 +187,19 @@ def update_tables(self):
153187
if max_length > 2048:
154188
logging.warning(
155189
"Very wide PMF with %d elements may lead to out of memory issues. "
156-
"Consider encoding distributions with smaller dispersion or "
157-
"increasing `tail_mass` parameter.", int(max_length))
190+
"Consider priors with smaller dispersion or increasing `tail_mass` "
191+
"parameter.", int(max_length))
158192
samples = tf.range(tf.cast(max_length, self.dtype), dtype=self.dtype)
159193
samples = tf.reshape(
160-
samples, [-1] + self.distribution.batch_shape.rank * [1])
194+
samples, [-1] + self.prior.batch_shape.rank * [1])
161195
samples += pmf_start
162-
pmf = self.distribution.prob(samples)
196+
pmf = self.prior.prob(samples)
163197

164198
# Collapse batch dimensions of distribution.
165199
pmf = tf.reshape(pmf, [max_length, -1])
166200
pmf = tf.transpose(pmf)
167201

168-
dist_shape = self.distribution.batch_shape_tensor()
202+
dist_shape = self.prior.batch_shape_tensor()
169203
pmf_length = tf.broadcast_to(pmf_length, dist_shape)
170204
pmf_length = tf.reshape(pmf_length, [-1])
171205
cdf_length = pmf_length + 2
@@ -187,4 +221,6 @@ def loop_body(args):
187221
cdf = tf.map_fn(
188222
loop_body, (pmf, pmf_length), dtype=tf.int32, name="pmf_to_cdf")
189223

190-
self._cdf, self._cdf_offset, self._cdf_length = cdf, cdf_offset, cdf_length
224+
self._cdf = tf.Variable(cdf, trainable=False)
225+
self._cdf_offset = tf.Variable(cdf_offset, trainable=False)
226+
self._cdf_length = tf.Variable(cdf_length, trainable=False)

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,22 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
3232
3333
This entropy model handles quantization of a bottleneck tensor and helps with
3434
training of the parameters of the probability distribution modeling the
35-
tensor. It also pre-computes integer probability tables, which can then be
36-
used to compress and decompress bottleneck tensors reliably across different
37-
platforms.
35+
tensor (a shared "prior" between sender and receiver). It also pre-computes
36+
integer probability tables, which can then be used to compress and decompress
37+
bottleneck tensors reliably across different platforms.
3838
3939
A typical workflow looks like this:
4040
41-
- Train a model using this entropy model as a bottleneck, passing the
42-
bottleneck tensor through `quantize()` while optimizing compressibility of
43-
the tensor using `bits()`. `bits(training=True)` computes a differentiable
44-
upper bound on the number of bits needed to compress the bottleneck tensor.
41+
- Train a model using an instance of this entropy model as a bottleneck,
42+
passing the bottleneck tensor through `quantize()` while optimizing
43+
compressibility of the tensor using `bits()`. `bits(training=True)` computes
44+
a differentiable upper bound on the number of bits needed to compress the
45+
bottleneck tensor.
4546
- For evaluation, get a closer estimate of the number of compressed bits
4647
using `bits(training=False)`.
47-
- Call `update_tables()` to ensure the probability tables for range coding are
48-
up-to-date.
49-
- Share the model between a sender and a receiver.
48+
- Instantiate an entropy model with `compression=True` (and the same
49+
parameters as during training), and share the model between a sender and a
50+
receiver.
5051
- On the sender side, compute the bottleneck tensor and call `compress()` on
5152
it. The output is a compressed string representation of the tensor. Transmit
5253
the string to the receiver, and call `decompress()` there. The output is the
@@ -56,9 +57,9 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
5657
This class assumes that all scalar elements of the encoded tensor are
5758
statistically independent, and that the parameters of their scalar
5859
distributions do not depend on data. The innermost dimensions of the
59-
bottleneck tensor must be broadcastable to the batch shape of `distribution`.
60-
Any dimensions to the left of the batch shape are assumed to be i.i.d., i.e.
61-
the likelihoods are broadcast to the bottleneck tensor accordingly.
60+
bottleneck tensor must be broadcastable to the batch shape of `prior`. Any
61+
dimensions to the left of the batch shape are assumed to be i.i.d., i.e. the
62+
likelihoods are broadcast to the bottleneck tensor accordingly.
6263
6364
A more detailed description (and motivation) of this way of performing
6465
quantization and range coding can be found in the following paper. Please cite
@@ -69,38 +70,44 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
6970
> https://openreview.net/forum?id=rJxdQ3jeg
7071
"""
7172

72-
def __init__(self, distribution, coding_rank,
73+
def __init__(self, prior, coding_rank, compression=False,
7374
likelihood_bound=1e-9, tail_mass=2**-8,
7475
range_coder_precision=12):
7576
"""Initializer.
7677
7778
Arguments:
78-
distribution: A `tfp.distributions.Distribution` object modeling the
79-
distribution of the bottleneck tensor values including additive uniform
80-
noise. The distribution parameters may not depend on data (they must be
81-
trainable variables or constants). For best results, the distribution
82-
should be flexible enough to have a unit-width uniform distribution as a
83-
special case, since this is the distribution an element will take on
84-
when its bottleneck value is constant (due to the additive noise).
79+
prior: A `tfp.distributions.Distribution` object. A density model fitting
80+
the marginal distribution of the bottleneck data with additive uniform
81+
noise, which is shared a priori between the sender and the receiver. For
82+
best results, the distribution should be flexible enough to have a
83+
unit-width uniform distribution as a special case, since this is the
84+
marginal distribution for bottleneck dimensions that are constant. The
85+
distribution parameters may not depend on data (they must be either
86+
variables or constants).
8587
coding_rank: Integer. Number of innermost dimensions considered a coding
8688
unit. Each coding unit is compressed to its own bit string, and the
8789
`bits()` method sums over each coding unit.
90+
compression: Boolean. If set to `True`, the range coding tables
91+
used by `compress()` and `decompress()` will be built on instantiation.
92+
Otherwise, some computation can be saved, but these two methods will not
93+
be accessible.
8894
likelihood_bound: Float. Lower bound for likelihood values, to prevent
8995
training instabilities.
9096
tail_mass: Float. Approximate probability mass which is range encoded with
9197
less precision, by using a Golomb-like code.
9298
range_coder_precision: Integer. Precision passed to the range coding op.
9399
"""
94-
if coding_rank < distribution.batch_shape.rank:
100+
if coding_rank < prior.batch_shape.rank:
95101
raise ValueError(
96-
"`coding_rank` can't be smaller than batch rank of `distribution`.")
102+
"`coding_rank` can't be smaller than batch rank of prior.")
97103
super().__init__(
98-
distribution, coding_rank, likelihood_bound=likelihood_bound,
99-
tail_mass=tail_mass, range_coder_precision=range_coder_precision)
104+
prior, coding_rank, compression=compression,
105+
likelihood_bound=likelihood_bound, tail_mass=tail_mass,
106+
range_coder_precision=range_coder_precision)
100107

101108
def _compute_indexes(self, broadcast_shape):
102109
# TODO(jonycgn, ssjhv): Investigate broadcasting in range coding op.
103-
dist_shape = self.distribution.batch_shape_tensor()
110+
dist_shape = self.prior.batch_shape_tensor()
104111
indexes = tf.range(tf.reduce_prod(dist_shape), dtype=tf.int32)
105112
indexes = tf.reshape(indexes, dist_shape)
106113
indexes = tf.broadcast_to(
@@ -113,9 +120,9 @@ def bits(self, bottleneck, training=True):
113120
Arguments:
114121
bottleneck: `tf.Tensor` containing the data to be compressed. Must have at
115122
least `self.coding_rank` dimensions, and the innermost dimensions must
116-
be broadcastable to `self.distribution.batch_shape`.
123+
be broadcastable to `self.prior.batch_shape`.
117124
training: Boolean. If `False`, computes the Shannon information of
118-
`bottleneck` under the distribution `self.distribution`, which is a
125+
`bottleneck` under the distribution `self.prior`, which is a
119126
non-differentiable, tight *lower* bound on the number of bits needed to
120127
compress `bottleneck` using `compress()`. If `True`, returns a somewhat
121128
looser, but differentiable *upper* bound on this quantity.
@@ -129,7 +136,7 @@ def bits(self, bottleneck, training=True):
129136
tf.shape(bottleneck), minval=-.5, maxval=.5, dtype=bottleneck.dtype)
130137
else:
131138
quantized = self.quantize(bottleneck)
132-
probs = self.distribution.prob(quantized)
139+
probs = self.prior.prob(quantized)
133140
probs = math_ops.lower_bound(probs, self.likelihood_bound)
134141
axes = tuple(range(-self.coding_rank, 0))
135142
bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / -tf.math.log(2.)
@@ -140,7 +147,7 @@ def quantize(self, bottleneck):
140147
141148
To use this entropy model as an information bottleneck during training, pass
142149
a tensor through this function. The tensor is rounded to integer values
143-
modulo `self.quantization_offset`, which depends on `self.distribution`. For
150+
modulo `self.quantization_offset`, which depends on `self.prior`. For
144151
instance, for a Gaussian distribution, the returned values are rounded to
145152
the location of the mode of the distribution plus or minus an integer.
146153
@@ -149,7 +156,7 @@ def quantize(self, bottleneck):
149156
150157
Arguments:
151158
bottleneck: `tf.Tensor` containing the data to be quantized. The innermost
152-
dimensions must be broadcastable to `self.distribution.batch_shape`.
159+
dimensions must be broadcastable to `self.prior.batch_shape`.
153160
154161
Returns:
155162
A `tf.Tensor` containing the quantized values.
@@ -162,7 +169,7 @@ def compress(self, bottleneck):
162169
163170
Compresses the tensor to bit strings. `bottleneck` is first quantized
164171
as in `quantize()`, and then compressed using the probability tables derived
165-
from `self.distribution`. The quantized tensor can later be recovered by
172+
from `self.prior`. The quantized tensor can later be recovered by
166173
calling `decompress()`.
167174
168175
The innermost `self.coding_rank` dimensions are treated as one coding unit,
@@ -172,7 +179,7 @@ def compress(self, bottleneck):
172179
Arguments:
173180
bottleneck: `tf.Tensor` containing the data to be compressed. Must have at
174181
least `self.coding_rank` dimensions, and the innermost dimensions must
175-
be broadcastable to `self.distribution.batch_shape`.
182+
be broadcastable to `self.prior.batch_shape`.
176183
177184
Returns:
178185
A `tf.Tensor` having the same shape as `bottleneck` without the
@@ -184,7 +191,7 @@ def compress(self, bottleneck):
184191
batch_shape, coding_shape = tf.split(
185192
input_shape, [input_rank - self.coding_rank, self.coding_rank])
186193
broadcast_shape = coding_shape[
187-
:self.coding_rank - self.distribution.batch_shape.rank]
194+
:self.coding_rank - self.prior.batch_shape.rank]
188195

189196
indexes = self._compute_indexes(broadcast_shape)
190197
offset = self.quantization_offset()
@@ -196,7 +203,7 @@ def compress(self, bottleneck):
196203
def loop_body(symbols):
197204
return range_coding_ops.unbounded_index_range_encode(
198205
symbols, indexes,
199-
self._cdf, self._cdf_length, self._cdf_offset,
206+
self.cdf, self.cdf_length, self.cdf_offset,
200207
precision=self.range_coder_precision,
201208
overflow_width=4, debug_level=1)
202209

@@ -217,15 +224,15 @@ def decompress(self, strings, broadcast_shape):
217224
strings: `tf.Tensor` containing the compressed bit strings.
218225
broadcast_shape: Iterable of ints. The part of the output tensor shape
219226
between the shape of `strings` on the left and
220-
`self.distribution.batch_shape` on the right. This must match the shape
227+
`self.prior.batch_shape` on the right. This must match the shape
221228
of the input to `compress()`.
222229
223230
Returns:
224231
A `tf.Tensor` of shape `strings.shape + broadcast_shape +
225-
self.distribution.batch_shape`.
232+
self.prior.batch_shape`.
226233
"""
227234
batch_shape = tf.shape(strings)
228-
dist_shape = self.distribution.batch_shape_tensor()
235+
dist_shape = self.prior.batch_shape_tensor()
229236
symbols_shape = tf.concat([batch_shape, broadcast_shape, dist_shape], 0)
230237

231238
indexes = self._compute_indexes(broadcast_shape)
@@ -236,7 +243,7 @@ def decompress(self, strings, broadcast_shape):
236243
def loop_body(string):
237244
return range_coding_ops.unbounded_index_range_decode(
238245
string, indexes,
239-
self._cdf, self._cdf_length, self._cdf_offset,
246+
self.cdf, self.cdf_length, self.cdf_offset,
240247
precision=self.range_coder_precision,
241248
overflow_width=4, debug_level=1)
242249

0 commit comments

Comments
 (0)