@@ -59,16 +59,16 @@ class ContinuousIndexedEntropyModel(continuous_base.ContinuousEntropyModelBase):
59
59
each bottleneck tensor element, it selects the appropriate scalar
60
60
distribution.
61
61
62
- The `indexes` tensor must contain only integer values (but may have
63
- floating-point type for purposes of backpropagation) in a pre-specified range.
64
- If `index_ranges` is a single integer, the index values must be in the range
65
- `[0, index_ranges)` and `indexes` must have the same shape as the bottleneck
66
- tensor. This only allows a one-dimensional conditional dependency. To make the
67
- distribution conditional on `n`-dimensional indexes, `index_ranges` must be
68
- specified as an iterable of `n` integers. Then, `indexes` must have the same
62
+ The `indexes` tensor must contain only integer values in a pre-specified range
63
+ (but may have floating-point type for purposes of backpropagation). To make
64
+ the distribution conditional on `n`-dimensional indexes, `index_ranges` must
65
+ be specified as an iterable of `n` integers. `indexes` must have the same
69
66
shape as the bottleneck tensor with an additional channel dimension of length
70
67
`n`. The position of the channel dimension is given by `channel_axis`. The
71
- index values in the `n`th channel must be in the range `[0, index_ranges[n])`.
68
+ index values in the `k`th channel must be in the range `[0, index_ranges[k])`.
69
+ If `index_ranges` has only one element (i.e. `n == 1`), `channel_axis` may be
70
+ `None`. In that case, the additional channel dimension is omitted, and the
71
+ `indexes` tensor must have the same shape as the bottleneck tensor.
72
72
73
73
The implied distribution for the bottleneck tensor is determined as:
74
74
```
@@ -89,12 +89,13 @@ class ContinuousIndexedEntropyModel(continuous_base.ContinuousEntropyModelBase):
89
89
```
90
90
tfc.ContinuousIndexedEntropyModel(
91
91
prior_fn=tfc.NoisyNormal,
92
- index_ranges=64 ,
92
+ index_ranges=(64,) ,
93
93
parameter_fns=dict(
94
94
loc=lambda _: 0.,
95
95
scale=lambda i: tf.exp(i / 8 - 5),
96
96
),
97
97
coding_rank=1,
98
+ channel_axis=None,
98
99
)
99
100
```
100
101
Then, each element of `indexes` in the range `[0, 64)` would indicate that the
@@ -149,12 +150,10 @@ def __init__(self,
149
150
since this is the marginal distribution for bottleneck dimensions that
150
151
are constant. The callable will receive keyword arguments as determined
151
152
by `parameter_fns`.
152
- index_ranges: Integer or iterable of integers. If a single integer,
153
- `indexes` must have the same shape as `bottleneck`, and `channel_axis`
154
- is ignored. Its values must be in the range `[0, index_ranges)`. If an
155
- iterable of integers, `indexes` must have an additional dimension at
156
- position `channel_axis`, and the values of the `n`th channel must be in
157
- the range `[0, index_ranges[n])`.
153
+ index_ranges: Iterable of integers. `indexes` must have the same shape as
154
+ the bottleneck tensor, with an additional dimension at position
155
+ `channel_axis`. The values of the `k`th channel must be in the range
156
+ `[0, index_ranges[k])`.
158
157
parameter_fns: Dict of strings to callables. Functions mapping `indexes`
159
158
to each distribution parameter. For each item, `indexes` is passed to
160
159
the callable, and the string key and return value make up one keyword
@@ -167,9 +166,10 @@ def __init__(self,
167
166
assumes eager mode (throws an error if in graph mode or inside a
168
167
`tf.function` call). If set to `False`, these two methods will not be
169
168
accessible.
170
- channel_axis: Integer. For iterable `index_ranges`, determines the
171
- position of the channel axis in `indexes`. Defaults to the last
172
- dimension.
169
+ channel_axis: Integer or `None`. Determines the position of the channel
170
+ axis in `indexes`. Defaults to the last dimension. If set to `None`,
171
+ the index tensor is expected to have the same shape as the bottleneck
172
+ tensor (only allowed when `index_ranges` has length 1).
173
173
dtype: `tf.dtypes.DType`. The data type of all floating-point
174
174
computations carried out in this class.
175
175
laplace_tail_mass: Float. If positive, will augment the prior with a
@@ -187,19 +187,24 @@ def __init__(self,
187
187
`compression=True` and not in eager execution mode.
188
188
"""
189
189
if coding_rank <= 0 :
190
- raise ValueError ("` coding_rank` must be larger than 0." )
190
+ raise ValueError ("coding_rank must be larger than 0." )
191
191
if not callable (prior_fn ):
192
- raise TypeError ("` prior_fn` must be a class or factory function." )
192
+ raise TypeError ("prior_fn must be a class or factory function." )
193
193
for name , fn in parameter_fns .items ():
194
194
if not isinstance (name , str ):
195
- raise TypeError ("` parameter_fns` must have string keys." )
195
+ raise TypeError ("parameter_fns must have string keys." )
196
196
if not callable (fn ):
197
- raise TypeError ("`parameter_fns['{}']` must be callable." .format (name ))
198
-
199
- prior = self ._make_range_coding_prior (prior_fn , index_ranges , parameter_fns ,
200
- channel_axis , dtype )
197
+ raise TypeError (f"parameter_fns['{ name } '] must be callable." )
198
+ self ._index_ranges = tuple (int (r ) for r in index_ranges )
199
+ if not self .index_ranges :
200
+ raise ValueError ("index_ranges must have at least one element." )
201
+ self ._channel_axis = None if channel_axis is None else int (channel_axis )
202
+ if self .channel_axis is None and len (self .index_ranges ) > 1 :
203
+ raise ValueError ("channel_axis can't be None for len(index_ranges) > 1." )
204
+ self ._prior_fn = prior_fn
205
+ self ._parameter_fns = dict (parameter_fns )
201
206
super ().__init__ (
202
- prior = prior ,
207
+ prior = self . _make_range_coding_prior ( self . index_ranges , dtype ) ,
203
208
coding_rank = coding_rank ,
204
209
compression = compression ,
205
210
laplace_tail_mass = laplace_tail_mass ,
@@ -208,14 +213,6 @@ def __init__(self,
208
213
range_coder_precision = range_coder_precision ,
209
214
no_variables = no_variables
210
215
)
211
- self ._channel_axis = int (channel_axis )
212
- self ._prior_fn = prior_fn
213
- # TODO(relational, jonycgn): Do we need special casing for int index_ranges?
214
- try :
215
- self ._index_ranges = int (index_ranges )
216
- except TypeError :
217
- self ._index_ranges = tuple (int (r ) for r in index_ranges ) # pytype:disable=attribute-error
218
- self ._parameter_fns = dict (parameter_fns )
219
216
220
217
@property
221
218
def index_ranges (self ):
@@ -242,23 +239,23 @@ def _make_prior(self, indexes, dtype=None):
242
239
parameters = {k : f (indexes ) for k , f in self .parameter_fns .items ()}
243
240
return self .prior_fn (** parameters )
244
241
245
- def _make_range_coding_prior (self , prior_fn , index_ranges , parameter_fns ,
246
- channel_axis , dtype ):
247
- del self # Method does not depend on instance state.
242
+ def _make_range_coding_prior (self , index_ranges , dtype ):
243
+ """Instantiates the range coding prior."""
248
244
dtype = tf .as_dtype (dtype )
249
- if isinstance (index_ranges , int ):
250
- indexes = tf .range (index_ranges , dtype = dtype )
245
+ if self .channel_axis is None :
246
+ index_range , = index_ranges
247
+ indexes = tf .range (index_range , dtype = dtype )
251
248
else :
252
249
indexes = [tf .range (r , dtype = dtype ) for r in index_ranges ]
253
250
indexes = tf .meshgrid (* indexes , indexing = "ij" )
254
- indexes = tf .stack (indexes , axis = channel_axis )
255
- parameters = {k : f (indexes ) for k , f in parameter_fns .items ()}
256
- return prior_fn (** parameters )
251
+ indexes = tf .stack (indexes , axis = self .channel_axis )
252
+ return self ._make_prior (indexes , dtype = dtype )
257
253
258
254
def _normalize_indexes (self , indexes ):
259
255
indexes = math_ops .lower_bound (indexes , 0 )
260
- if isinstance (self .index_ranges , int ):
261
- bounds = self .index_ranges - 1
256
+ if self .channel_axis is None :
257
+ index_range , = self .index_ranges
258
+ bounds = index_range - 1
262
259
else :
263
260
axes = [1 ] * indexes .shape .rank
264
261
axes [self .channel_axis ] = len (self .index_ranges )
@@ -268,7 +265,7 @@ def _normalize_indexes(self, indexes):
268
265
269
266
def _flatten_indexes (self , indexes ):
270
267
indexes = tf .cast (indexes , tf .int32 )
271
- if isinstance ( self .index_ranges , int ) :
268
+ if self .channel_axis is None :
272
269
return indexes
273
270
else :
274
271
strides = tf .math .cumprod (self .index_ranges , exclusive = True , reverse = True )
@@ -509,13 +506,14 @@ def __init__(self, prior_fn, num_scales, scale_fn, coding_rank,
509
506
num_scales = int (num_scales )
510
507
super ().__init__ (
511
508
prior_fn = prior_fn ,
512
- index_ranges = num_scales ,
509
+ index_ranges = ( num_scales ,) ,
513
510
parameter_fns = dict (
514
511
loc = lambda _ : 0. ,
515
512
scale = scale_fn ,
516
513
),
517
514
coding_rank = coding_rank ,
518
515
compression = compression ,
516
+ channel_axis = None ,
519
517
dtype = dtype ,
520
518
tail_mass = tail_mass ,
521
519
range_coder_precision = range_coder_precision ,
0 commit comments