15
15
"""Batched entropy model for continuous random variables."""
16
16
17
17
import functools
18
+ from absl import logging
18
19
import tensorflow as tf
19
20
from tensorflow_compression .python .distributions import helpers
20
21
from tensorflow_compression .python .entropy_models import continuous_base
@@ -55,20 +56,6 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
55
56
quantized bottleneck tensor. Continue processing the tensor on the receiving
56
57
side.
57
58
58
- Entropy models which contain range coding tables (i.e. with
59
- `compression=True`) can be instantiated in three ways:
60
-
61
- - By providing a continuous "prior" distribution object. The range coding
62
- tables are then derived from that continuous distribution.
63
- - From a config as returned by `get_config`, followed by a call to
64
- `set_weights`. This implements the Keras serialization protocol. In this
65
- case, the initializer creates empty state variables for the range coding
66
- tables, which are then filled by `set_weights`. As a consequence, this
67
- method requires `stateless=False`.
68
- - In a more low-level way, by directly providing the range coding tables to
69
- `__init__`, for use cases where the Keras protocol can't be used (e.g., when
70
- the entropy model must not create variables).
71
-
72
59
This class assumes that all scalar elements of the encoded tensor are
73
60
statistically independent, and that the parameters of their scalar
74
61
distributions do not depend on data. The innermost dimensions of the
@@ -83,6 +70,41 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
83
70
> "End-to-end Optimized Image Compression"<br />
84
71
> J. Ballé, V. Laparra, E.P. Simoncelli<br />
85
72
> https://openreview.net/forum?id=rJxdQ3jeg
73
+
74
+ Entropy models which contain range coding tables (i.e. with
75
+ `compression=True`) can be instantiated in three ways:
76
+
77
+ - By providing a continuous "prior" distribution object. The range coding
78
+ tables are then derived from that continuous distribution.
79
+ - From a config as returned by `get_config`, followed by a call to
80
+ `set_weights`. This implements the Keras serialization protocol. In this
81
+ case, the initializer creates empty state variables for the range coding
82
+ tables, which are then filled by `set_weights`. As a consequence, this
83
+ method requires `stateless=False`.
84
+ - In a more low-level way, by directly providing the range coding tables to
85
+ `__init__`, for use cases where the Keras protocol can't be used (e.g., when
86
+ the entropy model must not create variables).
87
+
88
+ The `quantization_offset` and `offset_heuristic` arguments control whether
89
+ quantization is performed with respect to integer values, or potentially
90
+ non-integer offsets (i.e., `y = tf.round(x - o) + o`). There are three modes
91
+ of operation:
92
+
93
+ - If `quantization_offset` is provided manually (not `None`), these values are
94
+ used and `offset_heuristic` has no effect.
95
+ - Otherwise, if `offset_heuristic and compression`, the offsets are computed
96
+ once on initialization, and then fixed. If the entropy model is serialized,
97
+ they are preserved.
98
+ - Otherwise, if `offset_heuristic and not compression`, the offsets are
99
+ recomputed every time quantization is performed. Note this may be
100
+ computationally expensive when the prior does not have a mode that is
101
+ computable in closed form (e.g. for `NoisyDeepFactorized`).
102
+
103
+ This offset heuristic is discussed in Section III.A of:
104
+ > "Nonlinear Transform Coding"<br />
105
+ > J. Ballé, P.A. Chou, D. Minnen, S. Singh, N. Johnston, E. Agustsson,
106
+ > S.J. Hwang, G. Toderici<br />
107
+ > https://doi.org/10.1109/JSTSP.2020.3034501
86
108
"""
87
109
88
110
def __init__ (self ,
@@ -98,7 +120,7 @@ def __init__(self,
98
120
cdf = None ,
99
121
cdf_offset = None ,
100
122
cdf_shapes = None ,
101
- non_integer_offset = True ,
123
+ offset_heuristic = True ,
102
124
quantization_offset = None ,
103
125
laplace_tail_mass = 0 ):
104
126
"""Initializes the instance.
@@ -141,11 +163,11 @@ def __init__(self,
141
163
cdf_shapes: Shapes of `cdf` and `cdf_offset`. If provided, empty range
142
164
coding tables are created, which can then be restored using
143
165
`set_weights`. Requires `compression=True` and `stateless=False`.
144
- non_integer_offset : Boolean. Whether to quantize to non-integer offsets
166
+ offset_heuristic : Boolean. Whether to quantize to non-integer offsets
145
167
heuristically determined from mode/median of prior. Set this to `False`
146
168
if you are using soft quantization during training.
147
- quantization_offset: `tf.Tensor` or `None`. If `cdf` is provided and
148
- `non_integer_offset=True`, must be provided as well .
169
+ quantization_offset: `tf.Tensor` or `None`. The quantization offsets to
170
+ use. If provided (not `None`), then `offset_heuristic` has no effect.
149
171
laplace_tail_mass: Float. If positive, will augment the prior with a
150
172
Laplace mixture for training stability. (experimental)
151
173
"""
@@ -171,33 +193,34 @@ def __init__(self,
171
193
laplace_tail_mass = laplace_tail_mass ,
172
194
)
173
195
self ._prior = prior
174
- self ._non_integer_offset = bool (non_integer_offset )
196
+ self ._offset_heuristic = bool (offset_heuristic )
175
197
self ._prior_shape = tf .TensorShape (
176
198
prior_shape if prior is None else prior .batch_shape )
177
199
if self .coding_rank < self .prior_shape .rank :
178
200
raise ValueError ("`coding_rank` can't be smaller than `prior_shape`." )
179
201
180
202
with self .name_scope :
181
- if quantization_offset is not None :
203
+ if cdf_shapes is not None :
204
+ # `cdf_shapes` being set indicates that we are using the `SavedModel`
205
+ # protocol, which can only provide JSON datatypes. So create a
206
+ # placeholder value depending on whether `quantization_offset` was
207
+ # `None` or not. For this purpose, we expect a Boolean (whereas in all
208
+ # other cases, we expect either `None` or a tensor).
209
+ assert isinstance (quantization_offset , bool )
210
+ assert self .compression
211
+ if quantization_offset :
212
+ quantization_offset = tf .zeros (
213
+ self .prior_shape_tensor , dtype = self .dtype )
214
+ else :
215
+ quantization_offset = None
216
+ elif quantization_offset is not None :
182
217
# If quantization offset is passed in manually, use it.
183
218
pass
184
- elif not self .non_integer_offset :
185
- # If not using the offset heuristic, always quantize to integers.
186
- quantization_offset = None
187
- elif cdf_shapes is not None :
188
- # `cdf_shapes` being set indicates that we are using the `SavedModel`
189
- # protocol. So create a placeholder value.
190
- quantization_offset = tf .zeros (
191
- self .prior_shape_tensor , dtype = self .dtype )
192
- elif cdf is not None :
193
- # CDF is passed in manually. So assume the same about the offsets.
194
- if quantization_offset is None :
219
+ elif self .offset_heuristic and self .compression :
220
+ # For compression, we need to fix the offset value, so compute it here.
221
+ if self ._prior is None :
195
222
raise ValueError (
196
- "When providing `cdf` and `non_integer_offset=True`, must also "
197
- "provide `quantization_offset`." )
198
- else :
199
- assert self ._prior is not None
200
- # If prior is available, determine offsets from it using the heuristic.
223
+ "To use the offset heuristic, a `prior` needs to be provided." )
201
224
quantization_offset = helpers .quantization_offset (self .prior )
202
225
# Optimization: if the quantization offset is zero, we don't need to
203
226
# subtract/add it when quantizing, and we don't need to serialize its
@@ -208,6 +231,8 @@ def __init__(self,
208
231
else :
209
232
quantization_offset = tf .broadcast_to (
210
233
quantization_offset , self .prior_shape_tensor )
234
+ else :
235
+ quantization_offset = None
211
236
if quantization_offset is None :
212
237
self ._quantization_offset = None
213
238
elif self .compression and not self .stateless :
@@ -234,14 +259,25 @@ def prior_shape_tensor(self):
234
259
return tf .constant (self .prior_shape .as_list (), dtype = tf .int32 )
235
260
236
261
@property
237
- def non_integer_offset (self ):
238
- return self ._non_integer_offset
262
+ def offset_heuristic (self ):
263
+ return self ._offset_heuristic
239
264
240
265
@property
241
266
def quantization_offset (self ):
242
- if self ._quantization_offset is None :
243
- return None
244
- return tf .convert_to_tensor (self ._quantization_offset )
267
+ if self ._quantization_offset is not None :
268
+ return tf .convert_to_tensor (self ._quantization_offset )
269
+ if self .offset_heuristic and not self .compression :
270
+ if self ._prior is None :
271
+ raise RuntimeError (
272
+ "To use the offset heuristic, a `prior` needs to be provided." )
273
+ if not tf .executing_eagerly ():
274
+ logging .warning (
275
+ "Computing quantization offsets using offset heuristic within a "
276
+ "tf.function. Ideally, the offset heuristic should only be used "
277
+ "to determine offsets once after training. Depending on the prior, "
278
+ "estimating the offset might be computationally expensive." )
279
+ return helpers .quantization_offset (self .prior )
280
+ return None
245
281
246
282
@tf .Module .with_name_scope
247
283
def __call__ (self , bottleneck , training = True ):
@@ -282,7 +318,7 @@ def quantize(self, bottleneck):
282
318
The tensor is rounded to integer values potentially shifted by offsets (if
283
319
`self.quantization_offset is not None`). These offsets can depend on
284
320
`self.prior`. For instance, for a Gaussian distribution, when
285
- `self.non_integer_offset == True`, the returned values would be rounded
321
+ `self.offset_heuristic == True`, the returned values would be rounded
286
322
to the location of the mode of the distribution plus or minus an integer.
287
323
288
324
The gradient of this rounding operation is overridden with the identity
@@ -379,9 +415,7 @@ def get_config(self):
379
415
config = super ().get_config ()
380
416
config .update (
381
417
prior_shape = tuple (map (int , self .prior_shape )),
382
- # Since the prior is never passed when using the `SavedModel` protocol,
383
- # we can reuse this flag to indicate whether the offsets need to be
384
- # loaded from a variable.
385
- non_integer_offset = self .quantization_offset is not None ,
418
+ offset_heuristic = self .offset_heuristic ,
419
+ quantization_offset = self .quantization_offset is not None ,
386
420
)
387
421
return config
0 commit comments