@@ -32,21 +32,22 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
 
   This entropy model handles quantization of a bottleneck tensor and helps with
   training of the parameters of the probability distribution modeling the
-  tensor. It also pre-computes integer probability tables, which can then be
-  used to compress and decompress bottleneck tensors reliably across different
-  platforms.
+  tensor (a shared "prior" between sender and receiver). It also pre-computes
+  integer probability tables, which can then be used to compress and decompress
+  bottleneck tensors reliably across different platforms.
 
   A typical workflow looks like this:
 
-  - Train a model using this entropy model as a bottleneck, passing the
-    bottleneck tensor through `quantize()` while optimizing compressibility of
-    the tensor using `bits()`. `bits(training=True)` computes a differentiable
-    upper bound on the number of bits needed to compress the bottleneck tensor.
+  - Train a model using an instance of this entropy model as a bottleneck,
+    passing the bottleneck tensor through `quantize()` while optimizing
+    compressibility of the tensor using `bits()`. `bits(training=True)` computes
+    a differentiable upper bound on the number of bits needed to compress the
+    bottleneck tensor.
   - For evaluation, get a closer estimate of the number of compressed bits
     using `bits(training=False)`.
-  - Call `update_tables()` to ensure the probability tables for range coding are
-    up-to-date.
-  - Share the model between a sender and a receiver.
+  - Instantiate an entropy model with `compression=True` (and the same
+    parameters as during training), and share the model between a sender and a
+    receiver.
   - On the sender side, compute the bottleneck tensor and call `compress()` on
     it. The output is a compressed string representation of the tensor. Transmit
     the string to the receiver, and call `decompress()` there. The output is the
@@ -56,9 +57,9 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
 
   This class assumes that all scalar elements of the encoded tensor are
   statistically independent, and that the parameters of their scalar
   distributions do not depend on data. The innermost dimensions of the
-  bottleneck tensor must be broadcastable to the batch shape of `distribution`.
-  Any dimensions to the left of the batch shape are assumed to be i.i.d., i.e.
-  the likelihoods are broadcast to the bottleneck tensor accordingly.
+  bottleneck tensor must be broadcastable to the batch shape of `prior`. Any
+  dimensions to the left of the batch shape are assumed to be i.i.d., i.e. the
+  likelihoods are broadcast to the bottleneck tensor accordingly.
 
   A more detailed description (and motivation) of this way of performing
   quantization and range coding can be found in the following paper. Please cite
@@ -69,38 +70,44 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
   > https://openreview.net/forum?id=rJxdQ3jeg
   """
 
-  def __init__(self, distribution, coding_rank,
+  def __init__(self, prior, coding_rank, compression=False,
                likelihood_bound=1e-9, tail_mass=2**-8,
                range_coder_precision=12):
     """Initializer.
 
     Arguments:
-      distribution: A `tfp.distributions.Distribution` object modeling the
-        distribution of the bottleneck tensor values including additive uniform
-        noise. The distribution parameters may not depend on data (they must be
-        trainable variables or constants). For best results, the distribution
-        should be flexible enough to have a unit-width uniform distribution as a
-        special case, since this is the distribution an element will take on
-        when its bottleneck value is constant (due to the additive noise).
+      prior: A `tfp.distributions.Distribution` object. A density model fitting
+        the marginal distribution of the bottleneck data with additive uniform
+        noise, which is shared a priori between the sender and the receiver. For
+        best results, the distribution should be flexible enough to have a
+        unit-width uniform distribution as a special case, since this is the
+        marginal distribution for bottleneck dimensions that are constant. The
+        distribution parameters may not depend on data (they must be either
+        variables or constants).
       coding_rank: Integer. Number of innermost dimensions considered a coding
         unit. Each coding unit is compressed to its own bit string, and the
         `bits()` method sums over each coding unit.
+      compression: Boolean. If set to `True`, the range coding tables
+        used by `compress()` and `decompress()` will be built on instantiation.
+        Otherwise, some computation can be saved, but these two methods will not
+        be accessible.
       likelihood_bound: Float. Lower bound for likelihood values, to prevent
         training instabilities.
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
       range_coder_precision: Integer. Precision passed to the range coding op.
     """
-    if coding_rank < distribution.batch_shape.rank:
+    if coding_rank < prior.batch_shape.rank:
       raise ValueError(
-          "`coding_rank` can't be smaller than batch rank of `distribution`.")
+          "`coding_rank` can't be smaller than batch rank of `prior`.")
     super().__init__(
-        distribution, coding_rank, likelihood_bound=likelihood_bound,
-        tail_mass=tail_mass, range_coder_precision=range_coder_precision)
+        prior, coding_rank, compression=compression,
+        likelihood_bound=likelihood_bound, tail_mass=tail_mass,
+        range_coder_precision=range_coder_precision)
 
   def _compute_indexes(self, broadcast_shape):
     # TODO(jonycgn, ssjhv): Investigate broadcasting in range coding op.
-    dist_shape = self.distribution.batch_shape_tensor()
+    dist_shape = self.prior.batch_shape_tensor()
     indexes = tf.range(tf.reduce_prod(dist_shape), dtype=tf.int32)
     indexes = tf.reshape(indexes, dist_shape)
     indexes = tf.broadcast_to(
@@ -113,9 +120,9 @@ def bits(self, bottleneck, training=True):
     Arguments:
       bottleneck: `tf.Tensor` containing the data to be compressed. Must have at
         least `self.coding_rank` dimensions, and the innermost dimensions must
-        be broadcastable to `self.distribution.batch_shape`.
+        be broadcastable to `self.prior.batch_shape`.
       training: Boolean. If `False`, computes the Shannon information of
-        `bottleneck` under the distribution `self.distribution`, which is a
+        `bottleneck` under the distribution `self.prior`, which is a
         non-differentiable, tight *lower* bound on the number of bits needed to
         compress `bottleneck` using `compress()`. If `True`, returns a somewhat
         looser, but differentiable *upper* bound on this quantity.
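
A minimal sketch of how the two `bits()` modes might be used, assuming a `tfc.NoisyNormal` prior from `tensorflow_compression` (the prior class, shapes, and values here are illustrative assumptions, not part of this change):

    import tensorflow as tf
    import tensorflow_compression as tfc

    # Hypothetical prior with batch shape (16,); any data-independent
    # `tfp.distributions.Distribution` including additive uniform noise works.
    prior = tfc.NoisyNormal(loc=0., scale=tf.ones(16))
    em = ContinuousBatchedEntropyModel(prior, coding_rank=1)

    y = tf.random.normal((32, 16))          # bottleneck tensor, 32 coding units
    rate_loss = em.bits(y, training=True)   # differentiable upper bound, shape (32,)
    eval_bits = em.bits(y, training=False)  # tight lower bound, not differentiable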
@@ -129,7 +136,7 @@ def bits(self, bottleneck, training=True):
           tf.shape(bottleneck), minval=-.5, maxval=.5, dtype=bottleneck.dtype)
     else:
       quantized = self.quantize(bottleneck)
-    probs = self.distribution.prob(quantized)
+    probs = self.prior.prob(quantized)
     probs = math_ops.lower_bound(probs, self.likelihood_bound)
     axes = tuple(range(-self.coding_rank, 0))
     bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / -tf.math.log(2.)
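
The conversion from likelihoods to bits in this hunk is just a change of logarithm base; a self-contained numeric check under assumed shapes:

    import tensorflow as tf

    # Two coding units of 16 elements each, every element with probability 1/2,
    # should cost exactly 16 bits per coding unit.
    probs = tf.fill((2, 16), 0.5)
    axes = tuple(range(-1, 0))  # coding_rank = 1
    bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / -tf.math.log(2.)
    print(bits.numpy())  # [16. 16.]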
@@ -140,7 +147,7 @@ def quantize(self, bottleneck):
 
     To use this entropy model as an information bottleneck during training, pass
     a tensor through this function. The tensor is rounded to integer values
-    modulo `self.quantization_offset`, which depends on `self.distribution`. For
+    modulo `self.quantization_offset`, which depends on `self.prior`. For
     instance, for a Gaussian distribution, the returned values are rounded to
     the location of the mode of the distribution plus or minus an integer.
 
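
A sketch of what "rounded to integer values modulo an offset" means, with a hypothetical offset standing in for `self.quantization_offset()`; this mirrors the documented behavior, not necessarily the exact implementation:

    import tensorflow as tf

    offset = 0.3  # hypothetical mode location of the prior
    x = tf.constant([-0.9, 0.1, 1.4])
    quantized = tf.round(x - offset) + offset  # snaps to the grid offset + k
    print(quantized.numpy())  # approximately [-0.7  0.3  1.3]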
@@ -149,7 +156,7 @@ def quantize(self, bottleneck):
 
     Arguments:
       bottleneck: `tf.Tensor` containing the data to be quantized. The innermost
-        dimensions must be broadcastable to `self.distribution.batch_shape`.
+        dimensions must be broadcastable to `self.prior.batch_shape`.
 
     Returns:
       A `tf.Tensor` containing the quantized values.
@@ -162,7 +169,7 @@ def compress(self, bottleneck):
 
     Compresses the tensor to bit strings. `bottleneck` is first quantized
     as in `quantize()`, and then compressed using the probability tables derived
-    from `self.distribution`. The quantized tensor can later be recovered by
+    from `self.prior`. The quantized tensor can later be recovered by
     calling `decompress()`.
 
     The innermost `self.coding_rank` dimensions are treated as one coding unit,
@@ -172,7 +179,7 @@ def compress(self, bottleneck):
     Arguments:
       bottleneck: `tf.Tensor` containing the data to be compressed. Must have at
         least `self.coding_rank` dimensions, and the innermost dimensions must
-        be broadcastable to `self.distribution.batch_shape`.
+        be broadcastable to `self.prior.batch_shape`.
 
     Returns:
       A `tf.Tensor` having the same shape as `bottleneck` without the
@@ -184,7 +191,7 @@ def compress(self, bottleneck):
     batch_shape, coding_shape = tf.split(
         input_shape, [input_rank - self.coding_rank, self.coding_rank])
     broadcast_shape = coding_shape[
-        :self.coding_rank - self.distribution.batch_shape.rank]
+        :self.coding_rank - self.prior.batch_shape.rank]
 
     indexes = self._compute_indexes(broadcast_shape)
     offset = self.quantization_offset()
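
A worked example of the shape split above, under assumed shapes: a bottleneck of shape (3, 5, 4, 16), `coding_rank=3`, and a prior with batch shape (16,):

    import tensorflow as tf

    input_shape, input_rank = tf.constant([3, 5, 4, 16]), 4
    coding_rank, prior_batch_rank = 3, 1
    batch_shape, coding_shape = tf.split(
        input_shape, [input_rank - coding_rank, coding_rank])
    broadcast_shape = coding_shape[:coding_rank - prior_batch_rank]
    # batch_shape -> [3], coding_shape -> [5, 4, 16], broadcast_shape -> [5, 4]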
@@ -196,7 +203,7 @@ def compress(self, bottleneck):
     def loop_body(symbols):
       return range_coding_ops.unbounded_index_range_encode(
           symbols, indexes,
-          self._cdf, self._cdf_length, self._cdf_offset,
+          self.cdf, self.cdf_length, self.cdf_offset,
           precision=self.range_coder_precision,
           overflow_width=4, debug_level=1)
 
@@ -217,15 +224,15 @@ def decompress(self, strings, broadcast_shape):
       strings: `tf.Tensor` containing the compressed bit strings.
       broadcast_shape: Iterable of ints. The part of the output tensor shape
         between the shape of `strings` on the left and
-        `self.distribution.batch_shape` on the right. This must match the shape
+        `self.prior.batch_shape` on the right. This must match the shape
         of the input to `compress()`.
 
     Returns:
       A `tf.Tensor` of shape `strings.shape + broadcast_shape +
-      self.distribution.batch_shape`.
+      self.prior.batch_shape`.
     """
     batch_shape = tf.shape(strings)
-    dist_shape = self.distribution.batch_shape_tensor()
+    dist_shape = self.prior.batch_shape_tensor()
     symbols_shape = tf.concat([batch_shape, broadcast_shape, dist_shape], 0)
 
     indexes = self._compute_indexes(broadcast_shape)
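
Continuing the assumed shapes from the `compress()` example above: strings of shape (3,), `broadcast_shape=(5, 4)`, and a prior batch shape of (16,) reassemble into a symbol tensor of shape (3, 5, 4, 16) before decoding:

    import tensorflow as tf

    batch_shape = tf.constant([3])         # shape of `strings`
    broadcast_shape = tf.constant([5, 4])  # supplied by the caller
    dist_shape = tf.constant([16])         # prior batch shape
    symbols_shape = tf.concat([batch_shape, broadcast_shape, dist_shape], 0)
    # symbols_shape -> [3, 5, 4, 16]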
@@ -236,7 +243,7 @@ def decompress(self, strings, broadcast_shape):
     def loop_body(string):
       return range_coding_ops.unbounded_index_range_decode(
           string, indexes,
-          self._cdf, self._cdf_length, self._cdf_offset,
+          self.cdf, self.cdf_length, self.cdf_offset,
           precision=self.range_coder_precision,
           overflow_width=4, debug_level=1)
 
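
Putting the pieces together, a hedged end-to-end sketch of the sender/receiver workflow described in the class docstring; the `tfc.NoisyNormal` prior and all shapes are illustrative assumptions:

    import tensorflow as tf
    import tensorflow_compression as tfc

    # Both sides build the model from identical parameters, with compression=True
    # so that the range coding tables are constructed.
    prior = tfc.NoisyNormal(loc=0., scale=tf.ones(16))
    em = ContinuousBatchedEntropyModel(prior, coding_rank=2, compression=True)

    # Sender: the innermost two dimensions (4, 16) form one coding unit each.
    y = tf.random.normal((8, 4, 16))
    strings = em.compress(y)              # shape (8,): one bit string per unit

    # Receiver: broadcast_shape is the coding-unit shape left of the prior's
    # batch shape, i.e. (4,) here.
    y_hat = em.decompress(strings, (4,))  # shape (8, 4, 16), equals em.quantize(y)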