Commit 49fe704

Johannes Ballé authored and copybara-github committed

Revises logic for offset heuristic.

- Changes argument `non_integer_offset` to `offset_heuristic` to explicitly
  control whether the heuristic is used, not non-integer offsets in general.
- There are three modes of operation:
  - If `quantization_offset` is provided manually (not `None`), these values
    are used and `offset_heuristic` is ineffective.
  - Otherwise, if `offset_heuristic and compression`, the offsets are computed
    once on initialization and then fixed.
  - Otherwise, if `offset_heuristic and not compression`, the offsets are
    recomputed every time quantization is performed.

PiperOrigin-RevId: 424318816
Change-Id: I5bb4bda296476a139e556b7299f05396bca81302

1 parent eea8bf2 · commit 49fe704
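The three modes listed in the commit message can be summarized as a plain dispatch function. This is an illustrative sketch of the selection logic only; the function name and the returned labels are made up here and are not part of the tensorflow_compression API:

```python
def resolve_offset_mode(quantization_offset, offset_heuristic, compression):
    """Returns which quantization-offset mode applies (illustrative sketch)."""
    if quantization_offset is not None:
        # Manually provided offsets always win; `offset_heuristic` is ignored.
        return "manual"
    if offset_heuristic and compression:
        # Offsets are computed once on initialization, then fixed.
        return "fixed_on_init"
    if offset_heuristic:
        # No compression: offsets are recomputed on every quantization call.
        return "recompute_per_call"
    # Heuristic disabled and no manual offsets: quantize to integers.
    return "integer"
```

Note that the first branch makes `offset_heuristic` ineffective whenever explicit offsets are given, matching the precedence described above.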

File tree: 2 files changed, +81 / -47 lines


tensorflow_compression/python/entropy_models/continuous_batched.py
(80 additions, 46 deletions)
@@ -15,6 +15,7 @@
 """Batched entropy model for continuous random variables."""
 
 import functools
+from absl import logging
 import tensorflow as tf
 from tensorflow_compression.python.distributions import helpers
 from tensorflow_compression.python.entropy_models import continuous_base
@@ -55,20 +56,6 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
   quantized bottleneck tensor. Continue processing the tensor on the receiving
   side.
 
-  Entropy models which contain range coding tables (i.e. with
-  `compression=True`) can be instantiated in three ways:
-
-  - By providing a continuous "prior" distribution object. The range coding
-    tables are then derived from that continuous distribution.
-  - From a config as returned by `get_config`, followed by a call to
-    `set_weights`. This implements the Keras serialization protocol. In this
-    case, the initializer creates empty state variables for the range coding
-    tables, which are then filled by `set_weights`. As a consequence, this
-    method requires `stateless=False`.
-  - In a more low-level way, by directly providing the range coding tables to
-    `__init__`, for use cases where the Keras protocol can't be used (e.g., when
-    the entropy model must not create variables).
-
   This class assumes that all scalar elements of the encoded tensor are
   statistically independent, and that the parameters of their scalar
   distributions do not depend on data. The innermost dimensions of the
@@ -83,6 +70,41 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
   > "End-to-end Optimized Image Compression"<br />
   > J. Ballé, V. Laparra, E.P. Simoncelli<br />
   > https://openreview.net/forum?id=rJxdQ3jeg
+
+  Entropy models which contain range coding tables (i.e. with
+  `compression=True`) can be instantiated in three ways:
+
+  - By providing a continuous "prior" distribution object. The range coding
+    tables are then derived from that continuous distribution.
+  - From a config as returned by `get_config`, followed by a call to
+    `set_weights`. This implements the Keras serialization protocol. In this
+    case, the initializer creates empty state variables for the range coding
+    tables, which are then filled by `set_weights`. As a consequence, this
+    method requires `stateless=False`.
+  - In a more low-level way, by directly providing the range coding tables to
+    `__init__`, for use cases where the Keras protocol can't be used (e.g., when
+    the entropy model must not create variables).
+
+  The `quantization_offset` and `offset_heuristic` arguments control whether
+  quantization is performed with respect to integer values, or potentially
+  non-integer offsets (i.e., `y = tf.round(x - o) + o`). There are three modes
+  of operation:
+
+  - If `quantization_offset` is provided manually (not `None`), these values are
+    used and `offset_heuristic` is ineffective.
+  - Otherwise, if `offset_heuristic and compression`, the offsets are computed
+    once on initialization, and then fixed. If the entropy model is serialized,
+    they are preserved.
+  - Otherwise, if `offset_heuristic and not compression`, the offsets are
+    recomputed every time quantization is performed. Note this may be
+    computationally expensive when the prior does not have a mode that is
+    computable in closed form (e.g. for `NoisyDeepFactorized`).
+
+  This offset heuristic is discussed in Section III.A of:
+  > "Nonlinear Transform Coding"<br />
+  > J. Ballé, P.A. Chou, D. Minnen, S. Singh, N. Johnston, E. Agustsson,
+  > S.J. Hwang, G. Toderici<br />
+  > https://doi.org/10.1109/JSTSP.2020.3034501
   """
 
   def __init__(self,
@@ -98,7 +120,7 @@ def __init__(self,
                cdf=None,
                cdf_offset=None,
                cdf_shapes=None,
-               non_integer_offset=True,
+               offset_heuristic=True,
                quantization_offset=None,
                laplace_tail_mass=0):
     """Initializes the instance.
@@ -141,11 +163,11 @@ def __init__(self,
       cdf_shapes: Shapes of `cdf` and `cdf_offset`. If provided, empty range
         coding tables are created, which can then be restored using
         `set_weights`. Requires `compression=True` and `stateless=False`.
-      non_integer_offset: Boolean. Whether to quantize to non-integer offsets
+      offset_heuristic: Boolean. Whether to quantize to non-integer offsets
         heuristically determined from mode/median of prior. Set this to `False`
         if you are using soft quantization during training.
-      quantization_offset: `tf.Tensor` or `None`. If `cdf` is provided and
-        `non_integer_offset=True`, must be provided as well.
+      quantization_offset: `tf.Tensor` or `None`. The quantization offsets to
+        use. If provided (not `None`), then `offset_heuristic` is ineffective.
      laplace_tail_mass: Float. If positive, will augment the prior with a
        Laplace mixture for training stability. (experimental)
    """
@@ -171,33 +193,34 @@ def __init__(self,
         laplace_tail_mass=laplace_tail_mass,
     )
     self._prior = prior
-    self._non_integer_offset = bool(non_integer_offset)
+    self._offset_heuristic = bool(offset_heuristic)
     self._prior_shape = tf.TensorShape(
         prior_shape if prior is None else prior.batch_shape)
     if self.coding_rank < self.prior_shape.rank:
       raise ValueError("`coding_rank` can't be smaller than `prior_shape`.")
 
     with self.name_scope:
-      if quantization_offset is not None:
+      if cdf_shapes is not None:
+        # `cdf_shapes` being set indicates that we are using the `SavedModel`
+        # protocol, which can only provide JSON datatypes. So create a
+        # placeholder value depending on whether `quantization_offset` was
+        # `None` or not. For this purpose, we expect a Boolean (when in all
+        # other cases, we expect either `None` or a tensor).
+        assert isinstance(quantization_offset, bool)
+        assert self.compression
+        if quantization_offset:
+          quantization_offset = tf.zeros(
+              self.prior_shape_tensor, dtype=self.dtype)
+        else:
+          quantization_offset = None
+      elif quantization_offset is not None:
         # If quantization offset is passed in manually, use it.
         pass
-      elif not self.non_integer_offset:
-        # If not using the offset heuristic, always quantize to integers.
-        quantization_offset = None
-      elif cdf_shapes is not None:
-        # `cdf_shapes` being set indicates that we are using the `SavedModel`
-        # protocol. So create a placeholder value.
-        quantization_offset = tf.zeros(
-            self.prior_shape_tensor, dtype=self.dtype)
-      elif cdf is not None:
-        # CDF is passed in manually. So assume the same about the offsets.
-        if quantization_offset is None:
+      elif self.offset_heuristic and self.compression:
+        # For compression, we need to fix the offset value, so compute it here.
+        if self._prior is None:
           raise ValueError(
-              "When providing `cdf` and `non_integer_offset=True`, must also "
-              "provide `quantization_offset`.")
-      else:
-        assert self._prior is not None
-        # If prior is available, determine offsets from it using the heuristic.
+              "To use the offset heuristic, a `prior` needs to be provided.")
         quantization_offset = helpers.quantization_offset(self.prior)
         # Optimization: if the quantization offset is zero, we don't need to
         # subtract/add it when quantizing, and we don't need to serialize its
@@ -208,6 +231,8 @@ def __init__(self,
         else:
           quantization_offset = tf.broadcast_to(
               quantization_offset, self.prior_shape_tensor)
+      else:
+        quantization_offset = None
       if quantization_offset is None:
         self._quantization_offset = None
       elif self.compression and not self.stateless:
@@ -234,14 +259,25 @@ def prior_shape_tensor(self):
     return tf.constant(self.prior_shape.as_list(), dtype=tf.int32)
 
   @property
-  def non_integer_offset(self):
-    return self._non_integer_offset
+  def offset_heuristic(self):
+    return self._offset_heuristic
 
   @property
   def quantization_offset(self):
-    if self._quantization_offset is None:
-      return None
-    return tf.convert_to_tensor(self._quantization_offset)
+    if self._quantization_offset is not None:
+      return tf.convert_to_tensor(self._quantization_offset)
+    if self.offset_heuristic and not self.compression:
+      if self._prior is None:
+        raise RuntimeError(
+            "To use the offset heuristic, a `prior` needs to be provided.")
+      if not tf.executing_eagerly():
+        logging.warning(
+            "Computing quantization offsets using offset heuristic within a "
+            "tf.function. Ideally, the offset heuristic should only be used "
+            "to determine offsets once after training. Depending on the prior, "
+            "estimating the offset might be computationally expensive.")
+      return helpers.quantization_offset(self.prior)
+    return None
 
   @tf.Module.with_name_scope
   def __call__(self, bottleneck, training=True):
@@ -282,7 +318,7 @@ def quantize(self, bottleneck):
     The tensor is rounded to integer values potentially shifted by offsets (if
     `self.quantization_offset is not None`). These offsets can depend on
     `self.prior`. For instance, for a Gaussian distribution, when
-    `self.non_integer_offset == True`, the returned values would be rounded
+    `self.offset_heuristic == True`, the returned values would be rounded
     to the location of the mode of the distribution plus or minus an integer.
 
     The gradient of this rounding operation is overridden with the identity
@@ -379,9 +415,7 @@ def get_config(self):
     config = super().get_config()
     config.update(
         prior_shape=tuple(map(int, self.prior_shape)),
-        # Since the prior is never passed when using the `SavedModel` protocol,
-        # we can reuse this flag to indicate whether the offsets need to be
-        # loaded from a variable.
-        non_integer_offset=self.quantization_offset is not None,
+        offset_heuristic=self.offset_heuristic,
+        quantization_offset=self.quantization_offset is not None,
     )
     return config
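The quantization rule referenced throughout this change, `y = tf.round(x - o) + o`, can be illustrated in isolation. The following NumPy sketch mirrors only that formula; it is not the library's implementation (which also handles gradient overrides and tensor shapes):

```python
import numpy as np

def quantize_with_offset(x, offset):
    """Rounds x onto the grid of integers shifted by `offset` (illustrative)."""
    return np.round(x - offset) + offset

x = np.array([0.3, 1.1, -0.6])
# With offset 0.0 this is plain rounding to integers: 0.0, 1.0, -1.0.
# With offset 0.25 the reproduction grid becomes ..., -0.75, 0.25, 1.25, ...
y_int = quantize_with_offset(x, 0.0)
y_off = quantize_with_offset(x, 0.25)
```

For a Gaussian prior, the heuristic chooses `offset` so that the mode of the distribution lies exactly on the reproduction grid, which is why disabling it (`offset_heuristic=False`) simply snaps values to integers.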

tensorflow_compression/python/entropy_models/continuous_batched_test.py
(1 addition, 1 deletion)
@@ -197,7 +197,7 @@ def compress(self, values):
   def test_small_cdfs_for_dirac_prior_without_quantization_offset(self):
     prior = uniform_noise.NoisyNormal(loc=100. * tf.range(16.), scale=1e-10)
     em = ContinuousBatchedEntropyModel(
-        prior, coding_rank=2, non_integer_offset=False, compression=True)
+        prior, coding_rank=2, offset_heuristic=False, compression=True)
     self.assertEqual(em.cdf_offset.shape[0], 16)
     self.assertLessEqual(em.cdf.shape[0], 16 * 6)
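The `get_config` change serializes only a boolean for `quantization_offset`, because the Keras/SavedModel config must consist of JSON datatypes; the actual offset tensor is restored separately via `set_weights`. The toy class below illustrates that round trip under those assumptions — the class name and the `from_config` placeholder logic are hypothetical stand-ins, not tensorflow_compression code:

```python
import json

class ToyOffsetConfig:
    """Hypothetical stand-in for the entropy model's offset-related state."""

    def __init__(self, offset_heuristic=True, quantization_offset=None):
        self.offset_heuristic = offset_heuristic
        self.quantization_offset = quantization_offset

    def get_config(self):
        # Only JSON datatypes are allowed here, so record the *presence*
        # of offsets as a boolean rather than the tensor itself.
        return {
            "offset_heuristic": self.offset_heuristic,
            "quantization_offset": self.quantization_offset is not None,
        }

    @classmethod
    def from_config(cls, config):
        # On restore, the boolean says whether to create a zero-filled
        # placeholder (to be overwritten by `set_weights`) or no offset.
        placeholder = [0.0] if config["quantization_offset"] else None
        return cls(config["offset_heuristic"], placeholder)

model = ToyOffsetConfig(quantization_offset=[0.25, 0.5])
config = json.loads(json.dumps(model.get_config()))  # JSON round trip
restored = ToyOffsetConfig.from_config(config)
```

This mirrors why the old code's reuse of `non_integer_offset` as a presence flag was replaced: the new config carries both the user's `offset_heuristic` setting and the offset-presence boolean explicitly.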
