
Commit 7654443

Johannes Ballé authored and copybara-github committed
Adds no_variables option to entropy models.
PiperOrigin-RevId: 332352604
Change-Id: I1fd01aeeda87e6b27d22431e0a7c9467eb2e5fb1
1 parent ad75044 commit 7654443

File tree

3 files changed, +75 -38 lines


tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 26 additions & 12 deletions
@@ -39,7 +39,7 @@ class ContinuousEntropyModelBase(tf.Module, metaclass=abc.ABCMeta):
   @abc.abstractmethod
   def __init__(self, prior, coding_rank, compression=False,
                likelihood_bound=1e-9, tail_mass=2**-8,
-               range_coder_precision=12):
+               range_coder_precision=12, no_variables=False):
     """Initializer.
 
     Arguments:
@@ -60,6 +60,8 @@ def __init__(self, prior, coding_rank, compression=False,
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
       range_coder_precision: Integer. Precision passed to the range coding op.
+      no_variables: Boolean. If True, creates range coding tables as `Tensor`s
+        rather than `Variable`s.
 
     Raises:
       RuntimeError: when attempting to instantiate an entropy model with
@@ -77,6 +79,7 @@ def __init__(self, prior, coding_rank, compression=False,
     self._likelihood_bound = float(likelihood_bound)
     self._tail_mass = float(tail_mass)
     self._range_coder_precision = int(range_coder_precision)
+    self._no_variables = bool(no_variables)
     if self.compression:
       self._build_tables(prior)
 
@@ -103,17 +106,17 @@ def _check_compression(self):
   @property
   def cdf(self):
     self._check_compression()
-    return tf.identity(self._cdf)
+    return tf.convert_to_tensor(self._cdf)
 
   @property
   def cdf_offset(self):
     self._check_compression()
-    return tf.identity(self._cdf_offset)
+    return tf.convert_to_tensor(self._cdf_offset)
 
   @property
   def cdf_length(self):
     self._check_compression()
-    return tf.identity(self._cdf_length)
+    return tf.convert_to_tensor(self._cdf_length)
 
   @property
   def dtype(self):
@@ -155,6 +158,11 @@ def range_coder_precision(self):
     """Precision passed to range coding op."""
     return self._range_coder_precision
 
+  @property
+  def no_variables(self):
+    """Whether range coding tables are created as `Tensor`s or `Variable`s."""
+    return self._no_variables
+
   @tf.custom_gradient
   def _quantize_no_offset(self, inputs):
     return tf.round(inputs), lambda x: x
@@ -247,11 +255,16 @@ def loop_body(args):
     cdf = tf.map_fn(
         loop_body, (pmf, pmf_length), dtype=tf.int32, name="pmf_to_cdf")
 
-    self._cdf = tf.Variable(cdf, trainable=False, name="cdf")
-    self._cdf_offset = tf.Variable(
-        cdf_offset, trainable=False, name="cdf_offset")
-    self._cdf_length = tf.Variable(
-        cdf_length, trainable=False, name="cdf_length")
+    if self.no_variables:
+      self._cdf = cdf
+      self._cdf_offset = cdf_offset
+      self._cdf_length = cdf_length
+    else:
+      self._cdf = tf.Variable(cdf, trainable=False, name="cdf")
+      self._cdf_offset = tf.Variable(
+          cdf_offset, trainable=False, name="cdf_offset")
+      self._cdf_length = tf.Variable(
+          cdf_length, trainable=False, name="cdf_length")
 
   @abc.abstractmethod
   def get_config(self):
@@ -264,10 +277,10 @@ def get_config(self):
       NotImplementedError: on attempting to call this method on an entropy model
         with `compression=False`.
     """
-    if not self.compression:
+    if self.no_variables or not self.compression:
       raise NotImplementedError(
-          "Serializing entropy models with `compression=False` is currently "
-          "not supported.")
+          "Serializing entropy models with `compression=False` or "
+          "`no_variables=True` is currently not supported.")
     return dict(
         dtype=self._dtype.name,
         prior_shape=self._prior_shape,
@@ -304,6 +317,7 @@ def from_config(cls, config):
     self._likelihood_bound = float(config["likelihood_bound"])
     self._tail_mass = float(config["tail_mass"])
     self._range_coder_precision = int(config["range_coder_precision"])
+    self._no_variables = False
 
     prior_size = functools.reduce(lambda x, y: x * y, self.prior_shape, 1)
     cdf_width = int(config["cdf_width"])
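Not part of the commit: a minimal sketch of what the new flag changes in practice. Since the base class is abstract, the sketch uses the batched subclass from the next file. It assumes eager execution and that `NoisyNormal` and `ContinuousBatchedEntropyModel` are exported at the top level of `tensorflow_compression`, as in the 2.x releases.

import tensorflow as tf
import tensorflow_compression as tfc

prior = tfc.NoisyNormal(loc=0., scale=1.)

# Default behavior: the range coding tables (cdf, cdf_offset, cdf_length)
# are created as tf.Variables and tracked by the tf.Module.
em_vars = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=True)
print([v.name for v in em_vars.variables])   # includes the cdf tables

# With no_variables=True, the same tables are kept as constant Tensors,
# so they no longer show up in the module's variable list.
em_const = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=True, no_variables=True)
print([v.name for v in em_const.variables])  # cdf tables absent
print(em_const.cdf.dtype)                    # still accessible as a Tensor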

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 34 additions & 21 deletions
@@ -75,7 +75,7 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
 
   def __init__(self, prior, coding_rank, compression=False,
                likelihood_bound=1e-9, tail_mass=2**-8,
-               range_coder_precision=12):
+               range_coder_precision=12, no_variables=False):
     """Initializer.
 
     Arguments:
@@ -98,6 +98,8 @@ def __init__(self, prior, coding_rank, compression=False,
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
      range_coder_precision: Integer. Precision passed to the range coding op.
+      no_variables: Boolean. If True, creates range coding tables as `Tensor`s
+        rather than `Variable`s.
 
     Raises:
       RuntimeError: when attempting to instantiate an entropy model with
@@ -107,27 +109,38 @@ def __init__(self, prior, coding_rank, compression=False,
       raise ValueError(
           "`coding_rank` can't be smaller than batch rank of prior.")
     super().__init__(
-        prior, coding_rank, compression=compression,
-        likelihood_bound=likelihood_bound, tail_mass=tail_mass,
-        range_coder_precision=range_coder_precision)
+        prior=prior,
+        coding_rank=coding_rank,
+        compression=compression,
+        likelihood_bound=likelihood_bound,
+        tail_mass=tail_mass,
+        range_coder_precision=range_coder_precision,
+        no_variables=no_variables,
+    )
 
     quantization_offset = helpers.quantization_offset(prior)
-    if self.compression:
-      # Optimization: if the quantization offset is zero, we don't need to
-      # subtract/add it when quantizing, and we don't need to serialize its
-      # value. Note that this code will only work in eager mode.
-      # TODO(jonycgn): Reconsider if this optimization is worth keeping once
-      # the implementation is stable.
-      if tf.executing_eagerly() and tf.reduce_all(
-          tf.equal(quantization_offset, 0.)):
-        quantization_offset = None
-      else:
-        quantization_offset = tf.broadcast_to(
-            quantization_offset, self.prior_shape_tensor)
+    # Optimization: if the quantization offset is zero, we don't need to
+    # subtract/add it when quantizing, and we don't need to serialize its value.
+    # Note that this code will only work in eager mode.
+    # TODO(jonycgn): Reconsider if this optimization is worth keeping once the
+    # implementation is stable.
+    if tf.executing_eagerly() and tf.reduce_all(
+        tf.equal(quantization_offset, 0.)):
+      quantization_offset = None
+    else:
+      quantization_offset = tf.broadcast_to(
+          quantization_offset, self.prior_shape_tensor)
+      if self.compression and not self.no_variables:
         quantization_offset = tf.Variable(
             quantization_offset, trainable=False, name="quantization_offset")
     self._quantization_offset = quantization_offset
 
+  @property
+  def quantization_offset(self):
+    if self._quantization_offset is None:
+      return None
+    return tf.convert_to_tensor(self._quantization_offset)
+
   def _compute_indexes(self, broadcast_shape):
     # TODO(jonycgn, ssjhv): Investigate broadcasting in range coding op.
     prior_size = functools.reduce(lambda x, y: x * y, self.prior_shape, 1)
@@ -187,7 +200,7 @@ def quantize(self, bottleneck):
     Returns:
       A `tf.Tensor` containing the quantized values.
     """
-    return self._quantize(bottleneck, self._quantization_offset)
+    return self._quantize(bottleneck, self.quantization_offset)
 
   @tf.Module.with_name_scope
   def compress(self, bottleneck):
@@ -220,8 +233,8 @@ def compress(self, bottleneck):
         :self.coding_rank - len(self.prior_shape)]
 
     indexes = self._compute_indexes(broadcast_shape)
-    if self._quantization_offset is not None:
-      bottleneck -= self._quantization_offset
+    if self.quantization_offset is not None:
+      bottleneck -= self.quantization_offset
     symbols = tf.cast(tf.round(bottleneck), tf.int32)
     symbols = tf.reshape(symbols, tf.concat([[-1], coding_shape], 0))
 
@@ -287,8 +300,8 @@ def loop_body(string):
 
     symbols = tf.reshape(symbols, symbols_shape)
     outputs = tf.cast(symbols, self.dtype)
-    if self._quantization_offset is not None:
-      outputs += self._quantization_offset
+    if self.quantization_offset is not None:
+      outputs += self.quantization_offset
     return outputs
 
   def get_config(self):
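Also not part of the commit: a rough round-trip sketch with `no_variables=True`, under the same assumptions as the sketch above (eager mode, top-level `tfc.` exports, a `NoisyNormal` prior whose quantization offset is zero).

import tensorflow as tf
import tensorflow_compression as tfc

em = tfc.ContinuousBatchedEntropyModel(
    tfc.NoisyNormal(loc=0., scale=1.), coding_rank=1,
    compression=True, no_variables=True)

x = tf.random.normal([4, 32])           # 4 vectors; the last axis is coded
strings = em.compress(x)                # one range-coded string per vector
x_hat = em.decompress(strings, [32])    # broadcast shape of the coded axis
print(float(tf.reduce_max(tf.abs(x - x_hat))))  # <= 0.5: rounding error only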

tensorflow_compression/python/entropy_models/continuous_indexed.py

Lines changed: 15 additions & 5 deletions
@@ -129,7 +129,7 @@ class ContinuousIndexedEntropyModel(continuous_base.ContinuousEntropyModelBase):
   def __init__(self, prior_fn, index_ranges, parameter_fns, coding_rank,
                compression=False, channel_axis=-1, dtype=tf.float32,
                likelihood_bound=1e-9, tail_mass=2**-8,
-               range_coder_precision=12):
+               range_coder_precision=12, no_variables=False):
     """Initializer.
 
     Arguments:
@@ -170,6 +170,8 @@ def __init__(self, prior_fn, index_ranges, parameter_fns, coding_rank,
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
      range_coder_precision: Integer. Precision passed to the range coding op.
+      no_variables: Boolean. If True, creates range coding tables as `Tensor`s
+        rather than `Variable`s.
 
     Raises:
       RuntimeError: when attempting to instantiate an entropy model with
@@ -204,9 +206,14 @@ def __init__(self, prior_fn, index_ranges, parameter_fns, coding_rank,
     prior = self.prior_fn(**parameters)  # pylint:disable=not-callable
 
     super().__init__(
-        prior, coding_rank, compression=compression,
-        likelihood_bound=likelihood_bound, tail_mass=tail_mass,
-        range_coder_precision=range_coder_precision)
+        prior=prior,
+        coding_rank=coding_rank,
+        compression=compression,
+        likelihood_bound=likelihood_bound,
+        tail_mass=tail_mass,
+        range_coder_precision=range_coder_precision,
+        no_variables=no_variables,
+    )
 
   @property
   def index_ranges(self):
@@ -433,7 +440,7 @@ class LocationScaleIndexedEntropyModel(ContinuousIndexedEntropyModel):
 
   def __init__(self, prior_fn, num_scales, scale_fn, coding_rank,
                compression=False, dtype=tf.float32, likelihood_bound=1e-9,
-               tail_mass=2**-8, range_coder_precision=12):
+               tail_mass=2**-8, range_coder_precision=12, no_variables=False):
     """Initializer.
 
     Arguments:
@@ -464,6 +471,8 @@ def __init__(self, prior_fn, num_scales, scale_fn, coding_rank,
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
      range_coder_precision: Integer. Precision passed to the range coding op.
+      no_variables: Boolean. If True, creates range coding tables as `Tensor`s
+        rather than `Variable`s.
     """
     num_scales = int(num_scales)
     super().__init__(
@@ -479,6 +488,7 @@ def __init__(self, prior_fn, num_scales, scale_fn, coding_rank,
         likelihood_bound=likelihood_bound,
         tail_mass=tail_mass,
         range_coder_precision=range_coder_precision,
+        no_variables=no_variables,
     )
 
   @tf.Module.with_name_scope
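A final hedged sketch, not from the commit: the flag is simply forwarded through the indexed models, here via `LocationScaleIndexedEntropyModel`, whose signature is shown in the diff above. The log-spaced scale table is an arbitrary illustration, and the top-level `tfc.` exports are again assumed.

import tensorflow as tf
import tensorflow_compression as tfc

# 64 scales on a log grid; scale_fn maps a (float) scale index to a
# standard deviation for the NoisyNormal prior.
em = tfc.LocationScaleIndexedEntropyModel(
    tfc.NoisyNormal, num_scales=64,
    scale_fn=lambda i: tf.exp(-5. + i / 8.),
    coding_rank=1, compression=True, no_variables=True)

print(em.no_variables)    # True: tables were built as Tensors, not Variables
print(em.cdf.shape)       # one CDF row per scale index
print(len(em.variables))  # the tables are not tracked as variables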
