@@ -75,7 +75,7 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):

  def __init__(self, prior, coding_rank, compression=False,
               likelihood_bound=1e-9, tail_mass=2**-8,
-               range_coder_precision=12):
+               range_coder_precision=12, no_variables=False):
    """Initializer.

    Arguments:
@@ -98,6 +98,8 @@ def __init__(self, prior, coding_rank, compression=False,
      tail_mass: Float. Approximate probability mass which is range encoded with
        less precision, by using a Golomb-like code.
      range_coder_precision: Integer. Precision passed to the range coding op.
+      no_variables: Boolean. If True, creates range coding tables as `Tensor`s
+        rather than `Variable`s.

    Raises:
      RuntimeError: when attempting to instantiate an entropy model with
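As a minimal usage sketch (not part of this change), the new flag is passed at construction time. This assumes the class is exported as `tfc.ContinuousBatchedEntropyModel` and that a prior such as `tfc.NoisyNormal` is available in this version of the library:

import tensorflow as tf
import tensorflow_compression as tfc

# Scalar prior: batch rank 0, so coding_rank=1 satisfies the check below.
prior = tfc.NoisyNormal(loc=0., scale=1.)
# With no_variables=True, the range coding tables are built as plain Tensors
# instead of tf.Variables, so they are not tracked as model variables.
em = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=True, no_variables=True)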
@@ -107,27 +109,38 @@ def __init__(self, prior, coding_rank, compression=False,
      raise ValueError(
          "`coding_rank` can't be smaller than batch rank of prior.")
    super().__init__(
-        prior, coding_rank, compression=compression,
-        likelihood_bound=likelihood_bound, tail_mass=tail_mass,
-        range_coder_precision=range_coder_precision)
+        prior=prior,
+        coding_rank=coding_rank,
+        compression=compression,
+        likelihood_bound=likelihood_bound,
+        tail_mass=tail_mass,
+        range_coder_precision=range_coder_precision,
+        no_variables=no_variables,
+    )

    quantization_offset = helpers.quantization_offset(prior)
-    if self.compression:
-      # Optimization: if the quantization offset is zero, we don't need to
-      # subtract/add it when quantizing, and we don't need to serialize its
-      # value. Note that this code will only work in eager mode.
-      # TODO(jonycgn): Reconsider if this optimization is worth keeping once
-      # the implementation is stable.
-      if tf.executing_eagerly() and tf.reduce_all(
-          tf.equal(quantization_offset, 0.)):
-        quantization_offset = None
-      else:
-        quantization_offset = tf.broadcast_to(
-            quantization_offset, self.prior_shape_tensor)
+    # Optimization: if the quantization offset is zero, we don't need to
+    # subtract/add it when quantizing, and we don't need to serialize its value.
+    # Note that this code will only work in eager mode.
+    # TODO(jonycgn): Reconsider if this optimization is worth keeping once the
+    # implementation is stable.
+    if tf.executing_eagerly() and tf.reduce_all(
+        tf.equal(quantization_offset, 0.)):
+      quantization_offset = None
+    else:
+      quantization_offset = tf.broadcast_to(
+          quantization_offset, self.prior_shape_tensor)
+      if self.compression and not self.no_variables:
        quantization_offset = tf.Variable(
            quantization_offset, trainable=False, name="quantization_offset")
    self._quantization_offset = quantization_offset

+  @property
+  def quantization_offset(self):
+    if self._quantization_offset is None:
+      return None
+    return tf.convert_to_tensor(self._quantization_offset)
+
  def _compute_indexes(self, broadcast_shape):
    # TODO(jonycgn, ssjhv): Investigate broadcasting in range coding op.
    prior_size = functools.reduce(lambda x, y: x * y, self.prior_shape, 1)
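For readability, the offset handling added above can be restated as a standalone sketch; the helper name `make_quantization_offset` is hypothetical, but the logic mirrors the constructor code in this hunk:

import tensorflow as tf

def make_quantization_offset(offset, prior_shape, compression, no_variables):
  # In eager mode, an all-zero offset is dropped entirely: there is nothing to
  # subtract/add during quantization and nothing to serialize with the model.
  if tf.executing_eagerly() and tf.reduce_all(tf.equal(offset, 0.)):
    return None
  # Otherwise broadcast it to the prior shape ...
  offset = tf.broadcast_to(offset, prior_shape)
  # ... and only wrap it in a Variable when compressing and variables are wanted.
  if compression and not no_variables:
    offset = tf.Variable(offset, trainable=False, name="quantization_offset")
  return offset

The new `quantization_offset` property then hides whether `None`, a `Tensor`, or a `Variable` is stored, always returning either `None` or a `Tensor`.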
@@ -187,7 +200,7 @@ def quantize(self, bottleneck):
    Returns:
      A `tf.Tensor` containing the quantized values.
    """
-    return self._quantize(bottleneck, self._quantization_offset)
+    return self._quantize(bottleneck, self.quantization_offset)

  @tf.Module.with_name_scope
  def compress(self, bottleneck):
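`_quantize` lives in the base class and is not shown in this diff; assuming it performs standard offset quantization (roughly `round(x - offset) + offset`), the change above only affects which form of the offset is passed in, not the numerical result:

import tensorflow as tf

x = tf.constant([0.4, 1.3, 2.6])
offset = 0.25  # hypothetical nonzero quantization offset
# Values snap to the grid shifted by the offset: [0.25, 1.25, 2.25]
print(tf.round(x - offset) + offset)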
@@ -220,8 +233,8 @@ def compress(self, bottleneck):
        :self.coding_rank - len(self.prior_shape)]

    indexes = self._compute_indexes(broadcast_shape)
-    if self._quantization_offset is not None:
-      bottleneck -= self._quantization_offset
+    if self.quantization_offset is not None:
+      bottleneck -= self.quantization_offset
    symbols = tf.cast(tf.round(bottleneck), tf.int32)
    symbols = tf.reshape(symbols, tf.concat([[-1], coding_shape], 0))
@@ -287,8 +300,8 @@ def loop_body(string):

    symbols = tf.reshape(symbols, symbols_shape)
    outputs = tf.cast(symbols, self.dtype)
-    if self._quantization_offset is not None:
-      outputs += self._quantization_offset
+    if self.quantization_offset is not None:
+      outputs += self.quantization_offset
    return outputs

  def get_config(self):
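Taken together, the `compress`/`decompress` changes keep the round trip intact: the offset, when present, is subtracted before rounding and added back after range decoding. A hypothetical end-to-end sketch, assuming `decompress(strings, broadcast_shape)` as in the library's public API:

import tensorflow as tf
import tensorflow_compression as tfc

em = tfc.ContinuousBatchedEntropyModel(
    tfc.NoisyNormal(loc=0., scale=1.), coding_rank=1, compression=True)
x = tf.random.normal([4, 16])          # 4 vectors; the last axis is range coded
strings = em.compress(x)               # offset subtracted before rounding
x_hat = em.decompress(strings, [16])   # offset added back after decoding
tf.debugging.assert_near(x_hat, em.quantize(x))  # matches the quantized values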