@@ -124,7 +124,7 @@ def __init__(self, prior, coding_rank, compression=False,
124
124
quantization_offset = None
125
125
else :
126
126
quantization_offset = tf .broadcast_to (
127
- quantization_offset , self .prior_shape )
127
+ quantization_offset , self .prior_shape_tensor )
128
128
quantization_offset = tf .Variable (
129
129
quantization_offset , trainable = False , name = "quantization_offset" )
130
130
self ._quantization_offset = quantization_offset
@@ -133,9 +133,9 @@ def _compute_indexes(self, broadcast_shape):
133
133
# TODO(jonycgn, ssjhv): Investigate broadcasting in range coding op.
134
134
prior_size = functools .reduce (lambda x , y : x * y , self .prior_shape , 1 )
135
135
indexes = tf .range (prior_size , dtype = tf .int32 )
136
- indexes = tf .reshape (indexes , self .prior_shape )
136
+ indexes = tf .reshape (indexes , self .prior_shape_tensor )
137
137
indexes = tf .broadcast_to (
138
- indexes , tf .concat ([broadcast_shape , self .prior_shape ], 0 ))
138
+ indexes , tf .concat ([broadcast_shape , self .prior_shape_tensor ], 0 ))
139
139
return indexes
140
140
141
141
@tf .Module .with_name_scope
@@ -164,7 +164,8 @@ def bits(self, bottleneck, training=True):
164
164
probs = self .prior .prob (quantized )
165
165
probs = math_ops .lower_bound (probs , self .likelihood_bound )
166
166
axes = tuple (range (- self .coding_rank , 0 ))
167
- bits = tf .reduce_sum (tf .math .log (probs ), axis = axes ) / - tf .math .log (2. )
167
+ bits = tf .reduce_sum (tf .math .log (probs ), axis = axes ) / (
168
+ - tf .math .log (tf .constant (2. , dtype = probs .dtype )))
168
169
return bits
169
170
170
171
@tf .Module .with_name_scope
@@ -265,7 +266,7 @@ def decompress(self, strings, broadcast_shape):
265
266
broadcast_shape = tf .convert_to_tensor (broadcast_shape , dtype = tf .int32 )
266
267
batch_shape = tf .shape (strings )
267
268
symbols_shape = tf .concat (
268
- [batch_shape , broadcast_shape , self .prior_shape ], 0 )
269
+ [batch_shape , broadcast_shape , self .prior_shape_tensor ], 0 )
269
270
270
271
indexes = self ._compute_indexes (broadcast_shape )
271
272
strings = tf .reshape (strings , [- 1 ])
@@ -321,7 +322,7 @@ def from_config(cls, config):
321
322
with self .name_scope :
322
323
# pylint:disable=protected-access
323
324
if config ["quantization_offset" ]:
324
- zeros = tf .zeros (self .prior_shape , dtype = self .dtype )
325
+ zeros = tf .zeros (self .prior_shape_tensor , dtype = self .dtype )
325
326
self ._quantization_offset = tf .Variable (
326
327
zeros , name = "quantization_offset" )
327
328
else :
0 commit comments