Skip to content

Commit 0c0e40f

Browse files
Johannes Ballé, copybara-github
authored and committed
Makes index_ranges always iterable.
We now allow an omitted channel dimension via channel_axis=None instead. PiperOrigin-RevId: 355495759 Change-Id: I52281d18d726384341c8a818c4b49b0d4196cc96
1 parent 0028f42 commit 0c0e40f

File tree

4 files changed

+56
-82
lines changed

4 files changed

+56
-82
lines changed

tensorflow_compression/python/entropy_models/continuous_indexed.py

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,16 @@ class ContinuousIndexedEntropyModel(continuous_base.ContinuousEntropyModelBase):
5959
each bottleneck tensor element, it selects the appropriate scalar
6060
distribution.
6161
62-
The `indexes` tensor must contain only integer values (but may have
63-
floating-point type for purposes of backpropagation) in a pre-specified range.
64-
If `index_ranges` is a single integer, the index values must be in the range
65-
`[0, index_ranges)` and `indexes` must have the same shape as the bottleneck
66-
tensor. This only allows a one-dimensional conditional dependency. To make the
67-
distribution conditional on `n`-dimensional indexes, `index_ranges` must be
68-
specified as an iterable of `n` integers. Then, `indexes` must have the same
62+
The `indexes` tensor must contain only integer values in a pre-specified range
63+
(but may have floating-point type for purposes of backpropagation). To make
64+
the distribution conditional on `n`-dimensional indexes, `index_ranges` must
65+
be specified as an iterable of `n` integers. `indexes` must have the same
6966
shape as the bottleneck tensor with an additional channel dimension of length
7067
`n`. The position of the channel dimension is given by `channel_axis`. The
71-
index values in the `n`th channel must be in the range `[0, index_ranges[n])`.
68+
index values in the `k`th channel must be in the range `[0, index_ranges[k])`.
69+
If `index_ranges` has only one element (i.e. `n == 1`), `channel_axis` may be
70+
`None`. In that case, the additional channel dimension is omitted, and the
71+
`indexes` tensor must have the same shape as the bottleneck tensor.
7272
7373
The implied distribution for the bottleneck tensor is determined as:
7474
```
@@ -89,12 +89,13 @@ class ContinuousIndexedEntropyModel(continuous_base.ContinuousEntropyModelBase):
8989
```
9090
tfc.ContinuousIndexedEntropyModel(
9191
prior_fn=tfc.NoisyNormal,
92-
index_ranges=64,
92+
index_ranges=(64,),
9393
parameter_fns=dict(
9494
loc=lambda _: 0.,
9595
scale=lambda i: tf.exp(i / 8 - 5),
9696
),
9797
coding_rank=1,
98+
channel_axis=None,
9899
)
99100
```
100101
Then, each element of `indexes` in the range `[0, 64)` would indicate that the
@@ -149,12 +150,10 @@ def __init__(self,
149150
since this is the marginal distribution for bottleneck dimensions that
150151
are constant. The callable will receive keyword arguments as determined
151152
by `parameter_fns`.
152-
index_ranges: Integer or iterable of integers. If a single integer,
153-
`indexes` must have the same shape as `bottleneck`, and `channel_axis`
154-
is ignored. Its values must be in the range `[0, index_ranges)`. If an
155-
iterable of integers, `indexes` must have an additional dimension at
156-
position `channel_axis`, and the values of the `n`th channel must be in
157-
the range `[0, index_ranges[n])`.
153+
index_ranges: Iterable of integers. `indexes` must have the same shape as
154+
the bottleneck tensor, with an additional dimension at position
155+
`channel_axis`. The values of the `k`th channel must be in the range
156+
`[0, index_ranges[k])`.
158157
parameter_fns: Dict of strings to callables. Functions mapping `indexes`
159158
to each distribution parameter. For each item, `indexes` is passed to
160159
the callable, and the string key and return value make up one keyword
@@ -167,9 +166,10 @@ def __init__(self,
167166
assumes eager mode (throws an error if in graph mode or inside a
168167
`tf.function` call). If set to `False`, these two methods will not be
169168
accessible.
170-
channel_axis: Integer. For iterable `index_ranges`, determines the
171-
position of the channel axis in `indexes`. Defaults to the last
172-
dimension.
169+
channel_axis: Integer or `None`. Determines the position of the channel
170+
axis in `indexes`. Defaults to the last dimension. If set to `None`,
171+
the index tensor is expected to have the same shape as the bottleneck
172+
tensor (only allowed when `index_ranges` has length 1).
173173
dtype: `tf.dtypes.DType`. The data type of all floating-point
174174
computations carried out in this class.
175175
laplace_tail_mass: Float. If positive, will augment the prior with a
@@ -187,19 +187,24 @@ def __init__(self,
187187
`compression=True` and not in eager execution mode.
188188
"""
189189
if coding_rank <= 0:
190-
raise ValueError("`coding_rank` must be larger than 0.")
190+
raise ValueError("coding_rank must be larger than 0.")
191191
if not callable(prior_fn):
192-
raise TypeError("`prior_fn` must be a class or factory function.")
192+
raise TypeError("prior_fn must be a class or factory function.")
193193
for name, fn in parameter_fns.items():
194194
if not isinstance(name, str):
195-
raise TypeError("`parameter_fns` must have string keys.")
195+
raise TypeError("parameter_fns must have string keys.")
196196
if not callable(fn):
197-
raise TypeError("`parameter_fns['{}']` must be callable.".format(name))
198-
199-
prior = self._make_range_coding_prior(prior_fn, index_ranges, parameter_fns,
200-
channel_axis, dtype)
197+
raise TypeError(f"parameter_fns['{name}'] must be callable.")
198+
self._index_ranges = tuple(int(r) for r in index_ranges)
199+
if not self.index_ranges:
200+
raise ValueError("index_ranges must have at least one element.")
201+
self._channel_axis = None if channel_axis is None else int(channel_axis)
202+
if self.channel_axis is None and len(self.index_ranges) > 1:
203+
raise ValueError("channel_axis can't be None for len(index_ranges) > 1.")
204+
self._prior_fn = prior_fn
205+
self._parameter_fns = dict(parameter_fns)
201206
super().__init__(
202-
prior=prior,
207+
prior=self._make_range_coding_prior(self.index_ranges, dtype),
203208
coding_rank=coding_rank,
204209
compression=compression,
205210
laplace_tail_mass=laplace_tail_mass,
@@ -208,14 +213,6 @@ def __init__(self,
208213
range_coder_precision=range_coder_precision,
209214
no_variables=no_variables
210215
)
211-
self._channel_axis = int(channel_axis)
212-
self._prior_fn = prior_fn
213-
# TODO(relational, jonycgn): Do we need special casing for int index_ranges?
214-
try:
215-
self._index_ranges = int(index_ranges)
216-
except TypeError:
217-
self._index_ranges = tuple(int(r) for r in index_ranges) # pytype:disable=attribute-error
218-
self._parameter_fns = dict(parameter_fns)
219216

220217
@property
221218
def index_ranges(self):
@@ -242,23 +239,23 @@ def _make_prior(self, indexes, dtype=None):
242239
parameters = {k: f(indexes) for k, f in self.parameter_fns.items()}
243240
return self.prior_fn(**parameters)
244241

245-
def _make_range_coding_prior(self, prior_fn, index_ranges, parameter_fns,
246-
channel_axis, dtype):
247-
del self # Method does not depend on instance state.
242+
def _make_range_coding_prior(self, index_ranges, dtype):
243+
"""Instantiates the range coding prior."""
248244
dtype = tf.as_dtype(dtype)
249-
if isinstance(index_ranges, int):
250-
indexes = tf.range(index_ranges, dtype=dtype)
245+
if self.channel_axis is None:
246+
index_range, = index_ranges
247+
indexes = tf.range(index_range, dtype=dtype)
251248
else:
252249
indexes = [tf.range(r, dtype=dtype) for r in index_ranges]
253250
indexes = tf.meshgrid(*indexes, indexing="ij")
254-
indexes = tf.stack(indexes, axis=channel_axis)
255-
parameters = {k: f(indexes) for k, f in parameter_fns.items()}
256-
return prior_fn(**parameters)
251+
indexes = tf.stack(indexes, axis=self.channel_axis)
252+
return self._make_prior(indexes, dtype=dtype)
257253

258254
def _normalize_indexes(self, indexes):
259255
indexes = math_ops.lower_bound(indexes, 0)
260-
if isinstance(self.index_ranges, int):
261-
bounds = self.index_ranges - 1
256+
if self.channel_axis is None:
257+
index_range, = self.index_ranges
258+
bounds = index_range - 1
262259
else:
263260
axes = [1] * indexes.shape.rank
264261
axes[self.channel_axis] = len(self.index_ranges)
@@ -268,7 +265,7 @@ def _normalize_indexes(self, indexes):
268265

269266
def _flatten_indexes(self, indexes):
270267
indexes = tf.cast(indexes, tf.int32)
271-
if isinstance(self.index_ranges, int):
268+
if self.channel_axis is None:
272269
return indexes
273270
else:
274271
strides = tf.math.cumprod(self.index_ranges, exclusive=True, reverse=True)
@@ -509,13 +506,14 @@ def __init__(self, prior_fn, num_scales, scale_fn, coding_rank,
509506
num_scales = int(num_scales)
510507
super().__init__(
511508
prior_fn=prior_fn,
512-
index_ranges=num_scales,
509+
index_ranges=(num_scales,),
513510
parameter_fns=dict(
514511
loc=lambda _: 0.,
515512
scale=scale_fn,
516513
),
517514
coding_rank=coding_rank,
518515
compression=compression,
516+
channel_axis=None,
519517
dtype=dtype,
520518
tail_mass=tail_mass,
521519
range_coder_precision=range_coder_precision,

tensorflow_compression/python/entropy_models/continuous_indexed_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ class ContinuousIndexedEntropyModelTest(tf.test.TestCase):
2626

2727
def test_can_instantiate_one_dimensional(self):
2828
em = continuous_indexed.ContinuousIndexedEntropyModel(
29-
uniform_noise.NoisyNormal, 64,
29+
uniform_noise.NoisyNormal, (64,),
3030
dict(loc=lambda _: 0, scale=lambda i: tf.exp(i / 8 - 5)), 1,
31-
compression=True)
31+
compression=True, channel_axis=None)
3232
self.assertIsInstance(em.prior, uniform_noise.NoisyNormal)
3333
self.assertEqual(em.coding_rank, 1)
3434
self.assertEqual(em.tail_mass, 2**-8)

tensorflow_compression/python/entropy_models/universal.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class UniversalBatchedEntropyModel(
6565
continuous_batched.ContinuousBatchedEntropyModel):
6666
"""Batched entropy model which implements Universal Quantization.
6767
68-
In contrast to the base class, which uses roundinig for quantization, here
68+
In contrast to the base class, which uses rounding for quantization, here
6969
"quantization" is performed by additive uniform noise, which is implemented with
7070
Universal Quantization.
7171
@@ -232,7 +232,7 @@ class UniversalIndexedEntropyModel(
232232
continuous_indexed.ContinuousIndexedEntropyModel):
233233
"""Indexed entropy model which implements Universal Quantization.
234234
235-
In contrast to the base class, which uses roundinig for quantization, here
235+
In contrast to the base class, which uses rounding for quantization, here
236236
"quantization" is performed by additive uniform noise, which is implemented with
237237
Universal Quantization.
238238
@@ -268,10 +268,9 @@ def __init__(self,
268268
since this is the marginal distribution for bottleneck dimensions that
269269
are constant. The callable will receive keyword arguments as determined
270270
by `parameter_fns`.
271-
index_ranges: Iterable of integers. If (non-empty), compared to
272-
`bottleneck`, `indexes` in __call__() must have an additional dimension
273-
at position `channel_axis`, and the values of the `n`th channel must be
274-
in the range `[0, index_ranges[n])`.
271+
index_ranges: Iterable of integers. Compared to `bottleneck`, `indexes`
272+
in `__call__()` must have an additional trailing dimension, and the
273+
values of the `k`th channel must be in the range `[0, index_ranges[k])`.
275274
parameter_fns: Dict of strings to callables. Functions mapping `indexes`
276275
to each distribution parameter. For each item, `indexes` is passed to
277276
the callable, and the string key and return value make up one keyword
@@ -302,10 +301,6 @@ def __init__(self,
302301
RuntimeError: when attempting to instantiate an entropy model with
303302
`compression=True` and not in eager execution mode.
304303
"""
305-
if isinstance(index_ranges, int):
306-
raise ValueError(
307-
"An iterable of integers is only supported for `index_ranges`.")
308-
309304
# Add extra indexes for noise levels.
310305
index_ranges_with_offsets = tuple([num_noise_levels] +
311306
[int(r) for r in index_ranges])
@@ -342,7 +337,7 @@ def index_ranges_without_offsets(self):
342337
return _index_ranges_without_offsets(self.index_ranges)
343338

344339
def _normalize_indexes(self, indexes):
345-
"""See base class."""
340+
"""See base class."""
346341
num_indexes = indexes.shape[-1] # Last dim of `indexes` should be static.
347342
if num_indexes == len(self.index_ranges):
348343
# Indexes have offsets.
@@ -364,20 +359,10 @@ def _offset_from_indexes(self, indexes_with_offsets):
364359
offset_indexes, self._num_noise_levels, dtype=self.dtype)
365360
return offset
366361

367-
def _make_range_coding_prior(self, prior_fn, index_ranges_with_offsets,
368-
parameter_fns, channel_axis, dtype):
369-
"""Computes the range coding prior."""
370-
del self # Method does not depend on instance state.
371-
dtype = tf.as_dtype(dtype)
372-
index_ranges_without_offsets = _index_ranges_without_offsets(
373-
index_ranges_with_offsets)
374-
indexes = [
375-
tf.range(r, dtype=dtype) for r in index_ranges_without_offsets
376-
]
377-
indexes = tf.meshgrid(*indexes, indexing="ij")
378-
indexes = tf.stack(indexes, axis=channel_axis)
379-
parameters = {k: f(indexes) for k, f in parameter_fns.items()}
380-
return prior_fn(**parameters)
362+
def _make_range_coding_prior(self, index_ranges, dtype):
363+
"""Instantiates the range coding prior."""
364+
return super()._make_range_coding_prior(
365+
_index_ranges_without_offsets(index_ranges), dtype)
381366

382367
def _offset_from_prior(self, prior):
383368
return _range_coding_offsets(self._num_noise_levels, self.prior_shape,

tensorflow_compression/python/entropy_models/universal_test.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,6 @@ def setUp(self):
149149
super().setUp()
150150
tf.random.set_seed(1234)
151151

152-
def test_cannot_instantiate_one_dimensional(self):
153-
with self.assertRaises(ValueError):
154-
universal.UniversalIndexedEntropyModel(
155-
uniform_noise.NoisyNormal,
156-
coding_rank=1,
157-
index_ranges=64,
158-
parameter_fns=dict(
159-
loc=lambda _: 0, scale=lambda i: tf.exp(i / 8 - 5)))
160-
161152
def test_can_instantiate_n_dimensional(self):
162153
em = universal.UniversalIndexedEntropyModel(
163154
uniform_noise.NoisyLogisticMixture,

0 commit comments

Comments
 (0)