Skip to content

Commit acbf653

Browse files
author
Johannes Ballé
committed
Fix quantized CDF shape problem.
1 parent accedfc commit acbf653

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

tensorflow_compression/python/layers/entropy_models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,17 @@ def quantiles_initializer(shape, dtype=None, partition_info=None):
359359

360360
cdf = coder_ops.pmf_to_quantized_cdf(
361361
pmf, precision=self.range_coder_precision)
362+
363+
# We need to supply an initializer without fully defined static shape here,
364+
# or the variable will return the wrong dynamic shape later. A placeholder
365+
# with default gets the trick done.
366+
def cdf_init(*args, **kwargs):
367+
return array_ops.placeholder_with_default(
368+
array_ops.zeros((channels, 1), dtype=dtypes.int32),
369+
shape=(channels, None))
370+
362371
self._quantized_cdf = self.add_variable(
363-
"quantized_cdf", shape=(channels, 1), dtype=dtypes.int32,
372+
"quantized_cdf", shape=None, initializer=cdf_init, dtype=dtypes.int32,
364373
trainable=False)
365374

366375
update_op = state_ops.assign(

tensorflow_compression/python/layers/entropy_models_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def test_normalization(self):
288288
likelihood, = sess.run([likelihood], {inputs: x})
289289
self.assertEqual(x.shape, likelihood.shape)
290290
integral = np.sum(likelihood) * .001
291-
self.assertAllClose(1, integral, rtol=0, atol=1e-4)
291+
self.assertAllClose(1, integral, rtol=0, atol=2e-4)
292292

293293
def test_entropy_estimates(self):
294294
# Test that entropy estimates match actual range coding.

0 commit comments

Comments
 (0)