@@ -43,7 +43,6 @@ def __init__(self,
43
43
stateless = False ,
44
44
expected_grads = False ,
45
45
tail_mass = 2 ** - 8 ,
46
- range_coder_precision = 12 ,
47
46
dtype = None ,
48
47
laplace_tail_mass = 0 ):
49
48
"""Initializes the instance.
@@ -65,12 +64,11 @@ def __init__(self,
65
64
`stateless=True` is implied and the provided value is ignored.
66
65
expected_grads: If True, will use analytical expected gradients during
67
66
backpropagation w.r.t. additive uniform noise.
68
- tail_mass: Float. Approximate probability mass which is range encoded with
69
- less precision, by using a Golomb-like code.
70
- range_coder_precision: Integer. Precision passed to the range coding op.
67
+ tail_mass: Float. Approximate probability mass which is encoded using an
68
+ Elias gamma code embedded into the range coder.
71
69
dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of
72
70
prior, decompressed values).
73
- laplace_tail_mass: Float. If positive , will augment the prior with a
71
+ laplace_tail_mass: Float. If non-zero, will augment the prior with a
74
72
Laplace mixture for training stability. (experimental)
75
73
"""
76
74
super ().__init__ ()
@@ -80,11 +78,18 @@ def __init__(self,
80
78
self ._stateless = bool (stateless )
81
79
self ._expected_grads = bool (expected_grads )
82
80
self ._tail_mass = float (tail_mass )
83
- self ._range_coder_precision = int (range_coder_precision )
84
81
self ._dtype = tf .as_dtype (dtype )
85
82
self ._laplace_tail_mass = float (laplace_tail_mass )
83
+
84
+ if self .coding_rank < 0 :
85
+ raise ValueError ("`coding_rank` must be at least 0." )
86
+ if not 0 < self .tail_mass < 1 :
87
+ raise ValueError ("`tail_mass` must be between 0 and 1." )
88
+ if not 0 <= self .laplace_tail_mass < 1 :
89
+ raise ValueError ("`laplace_tail_mass` must be between 0 and 1." )
90
+
86
91
with self .name_scope :
87
- self ._laplace_prior = (tfp .distributions .Laplace (loc = 0.0 , scale = 1.0 )
92
+ self ._laplace_prior = (tfp .distributions .Laplace (loc = 0. , scale = 1. )
88
93
if laplace_tail_mass else None )
89
94
90
95
def _check_compression (self ):
@@ -117,11 +122,6 @@ def cdf_offset(self):
117
122
self ._check_compression ()
118
123
return tf .convert_to_tensor (self ._cdf_offset )
119
124
120
- @property
121
- def cdf_length (self ):
122
- self ._check_compression ()
123
- return tf .convert_to_tensor (self ._cdf_length )
124
-
125
125
@property
126
126
def dtype (self ):
127
127
"""Data type of this entropy model."""
@@ -159,16 +159,16 @@ def tail_mass(self):
159
159
160
160
@property
161
161
def range_coder_precision (self ):
162
- """Precision passed to range coding op."""
163
- return self ._range_coder_precision
162
+ """Precision used in range coding op."""
163
+ return - self .cdf [ 0 ]
164
164
165
- def _init_compression (self , cdf , cdf_offset , cdf_length , cdf_shape ):
165
+ def _init_compression (self , cdf , cdf_offset , cdf_shapes ):
166
166
"""Sets up this entropy model for using the range coder.
167
167
168
- This is done by storing `cdf`, `cdf_offset`, and `cdf_length` in
169
- `tf.Variable`s (`stateless=False`) or `tf.Tensor`s (`stateless=True`) as
170
- attributes of this object, or creating the variables as placeholders if
171
- `cdf_shape` is provided.
168
+ This is done by storing `cdf` and `cdf_offset` in `tf.Variable`s
169
+ (`stateless=False`) or `tf.Tensor`s (`stateless=True`) as attributes of this
170
+ object, or creating the variables as placeholders if `cdf_shapes` is
171
+ provided.
172
172
173
173
The reason for pre-computing the tables is that they must not be
174
174
re-generated independently on the sending and receiving side, since small
@@ -184,41 +184,33 @@ def _init_compression(self, cdf, cdf_offset, cdf_length, cdf_shape):
184
184
Args:
185
185
cdf: CDF table for range coder.
186
186
cdf_offset: CDF offset table for range coder.
187
- cdf_length: CDF length table for range coder.
188
- cdf_shape: Iterable of 2 integers, the shape of `cdf`. Mutually exclusive
189
- with the other three arguments. If provided, creates placeholder values
190
- for them.
187
+ cdf_shapes: Iterable of two integers, the lengths of `cdf` and `cdf_offset`.
188
+ Mutually exclusive with the other two arguments. If provided, creates
189
+ placeholder values for them.
191
190
"""
192
- if not ((cdf is None ) == (cdf_offset is None ) == (cdf_length is None ) ==
193
- (cdf_shape is not None )):
191
+ if not (cdf is None ) == (cdf_offset is None ) == (cdf_shapes is not None ):
194
192
raise ValueError (
195
- "Either all of `cdf`, `cdf_offset`, and `cdf_length`; or `cdf_shape` "
196
- "must be provided." )
197
- if cdf_shape is not None :
193
+ "Either both `cdf` and `cdf_offset`, or `cdf_shapes` must be "
194
+ "provided." )
195
+ if cdf_shapes is not None :
198
196
if self .stateless :
199
- raise ValueError ("With `stateless=True`, can't provide `cdf_shape`." )
200
- cdf_shape = tuple (map (int , cdf_shape ))
201
- if len (cdf_shape ) != 2 :
202
- raise ValueError ("`cdf_shape` must consist of 2 integers." )
203
- zeros = tf .zeros (cdf_shape , dtype = tf .int32 )
204
- cdf = zeros
205
- cdf_offset = zeros [:, 0 ]
206
- cdf_length = zeros [:, 0 ]
197
+ raise ValueError ("With `stateless=True`, can't provide `cdf_shapes`." )
198
+ cdf_shapes = tuple (map (int , cdf_shapes ))
199
+ if len (cdf_shapes ) != 2 :
200
+ raise ValueError ("`cdf_shapes` must have two elements." )
201
+ cdf = tf .zeros (cdf_shapes [:1 ], dtype = tf .int32 )
202
+ cdf_offset = tf .zeros (cdf_shapes [1 :], dtype = tf .int32 )
207
203
if self .stateless :
208
204
self ._cdf = tf .convert_to_tensor (cdf , dtype = tf .int32 , name = "cdf" )
209
205
self ._cdf_offset = tf .convert_to_tensor (
210
206
cdf_offset , dtype = tf .int32 , name = "cdf_offset" )
211
- self ._cdf_length = tf .convert_to_tensor (
212
- cdf_length , dtype = tf .int32 , name = "cdf_length" )
213
207
else :
214
208
self ._cdf = tf .Variable (
215
209
cdf , dtype = tf .int32 , trainable = False , name = "cdf" )
216
210
self ._cdf_offset = tf .Variable (
217
211
cdf_offset , dtype = tf .int32 , trainable = False , name = "cdf_offset" )
218
- self ._cdf_length = tf .Variable (
219
- cdf_length , dtype = tf .int32 , trainable = False , name = "cdf_length" )
220
212
221
- def _build_tables (self , prior , offset = None , context_shape = None ):
213
+ def _build_tables (self , prior , precision , offset = None ):
222
214
"""Computes integer-valued probability tables used by the range coder.
223
215
224
216
These tables must not be re-generated independently on the sending and
@@ -233,18 +225,16 @@ def _build_tables(self, prior, offset=None, context_shape=None):
233
225
234
226
Args:
235
227
prior: The `tfp.distributions.Distribution` object (see initializer).
236
- offset: Quantization offsets to use for sampling prior probabilities.
237
- Defaults to 0.
238
- context_shape: Shape of innermost dimensions to evaluate the prior on.
239
- Defaults to and must include `prior.batch_shape`.
228
+ precision: Integer. Precision for range coder.
229
+ offset: None or float tensor between -.5 and +.5. Sub-integer quantization
230
+ offsets to use for sampling prior probabilities. Defaults to 0.
240
231
241
232
Returns:
242
233
CDF table, CDF offsets.
243
234
"""
235
+ precision = int (precision )
244
236
if offset is None :
245
237
offset = 0.
246
- if context_shape is None :
247
- context_shape = tf .TensorShape (prior .batch_shape )
248
238
# Subclasses should have already caught this, but better be safe.
249
239
assert not prior .event_shape .rank
250
240
@@ -269,38 +259,38 @@ def _build_tables(self, prior, offset=None, context_shape=None):
269
259
"Consider priors with smaller variance, or increasing `tail_mass` "
270
260
"parameter." , int (max_length ))
271
261
samples = tf .range (tf .cast (max_length , self .dtype ), dtype = self .dtype )
272
- samples = tf .reshape (samples , [- 1 ] + context_shape .rank * [1 ])
262
+ samples = tf .reshape (samples , [- 1 ] + pmf_length . shape .rank * [1 ])
273
263
samples += pmf_start
274
264
pmf = prior .prob (samples )
265
+ pmf_shape = tf .shape (pmf )[1 :]
266
+ num_pmfs = tf .reduce_prod (pmf_shape )
275
267
276
268
# Collapse batch dimensions of distribution.
277
- pmf = tf .reshape (pmf , [max_length , - 1 ])
269
+ pmf = tf .reshape (pmf , [max_length , num_pmfs ])
278
270
pmf = tf .transpose (pmf )
279
271
280
- context_shape = tf .constant (context_shape .as_list (), dtype = tf .int32 )
281
- pmf_length = tf .broadcast_to (pmf_length , context_shape )
282
- pmf_length = tf .reshape (pmf_length , [- 1 ])
283
- cdf_length = pmf_length + 2
284
- cdf_offset = tf .broadcast_to (minima , context_shape )
285
- cdf_offset = tf .reshape (cdf_offset , [- 1 ])
272
+ pmf_length = tf .broadcast_to (pmf_length , pmf_shape )
273
+ pmf_length = tf .reshape (pmf_length , [num_pmfs ])
274
+ cdf_offset = tf .broadcast_to (minima , pmf_shape )
275
+ cdf_offset = tf .reshape (cdf_offset , [num_pmfs ])
276
+ precision_tensor = tf .constant ([- precision ], dtype = tf .int32 )
286
277
287
278
# Prevent tensors from bouncing back and forth between host and GPU.
288
279
with tf .device ("/cpu:0" ):
289
- def loop_body (args ):
290
- prob , length = args
291
- prob = prob [:length ]
292
- overflow = tf .math .maximum (1 - tf .reduce_sum (prob , keepdims = True ), 0. )
293
- prob = tf .concat ([prob , overflow ], axis = 0 )
294
- cdf = gen_ops .pmf_to_quantized_cdf (
295
- tf .cast (prob , tf .float32 ), precision = self .range_coder_precision )
296
- return tf .pad (
297
- cdf , [[0 , max_length - length ]], mode = "CONSTANT" , constant_values = 0 )
298
-
299
- # TODO(jonycgn,ssjhv): Consider switching to Python control flow.
300
- cdf = tf .map_fn (
301
- loop_body , (pmf , pmf_length ), dtype = tf .int32 , name = "pmf_to_cdf" )
302
-
303
- return cdf , cdf_offset , cdf_length
280
+ def loop_body (i , cdf ):
281
+ p = pmf [i , :pmf_length [i ]]
282
+ overflow = tf .math .maximum (1. - tf .reduce_sum (p , keepdims = True ), 0. )
283
+ p = tf .cast (tf .concat ([p , overflow ], 0 ), tf .float32 )
284
+ c = gen_ops .pmf_to_quantized_cdf (p , precision = precision )
285
+ return i + 1 , tf .concat ([cdf , precision_tensor , c ], 0 )
286
+ i_0 = tf .constant (0 , tf .int32 )
287
+ cdf_0 = tf .constant ([], tf .int32 )
288
+ _ , cdf = tf .while_loop (
289
+ lambda i , _ : i < num_pmfs , loop_body , (i_0 , cdf_0 ),
290
+ shape_invariants = (i_0 .shape , tf .TensorShape ([None ])),
291
+ name = "pmf_to_cdf" )
292
+
293
+ return cdf , cdf_offset
304
294
305
295
def _log_prob (self , prior , bottleneck_perturbed ):
306
296
"""Evaluates prior.log_prob(bottleneck + noise)."""
@@ -341,8 +331,7 @@ def get_config(self):
341
331
stateless = False ,
342
332
expected_grads = self .expected_grads ,
343
333
tail_mass = self .tail_mass ,
344
- range_coder_precision = self .range_coder_precision ,
345
- cdf_shape = tuple (map (int , self .cdf .shape )),
334
+ cdf_shapes = (self .cdf .shape [0 ], self .cdf_offset .shape [0 ]),
346
335
dtype = self .dtype .name ,
347
336
laplace_tail_mass = self .laplace_tail_mass ,
348
337
)
0 commit comments