
Commit cb1dee3

Johannes Ballé authored and copybara-github committed
Cleans up a few instances of Tensor type casting.
Adds an attribute `prior_shape_tensor`, analogous to the shape attributes on tfp Distributions.

PiperOrigin-RevId: 320462150
Change-Id: I5660cfcaf61db231054481baf14c5837903fc0c7
1 parent: 9cbc850

3 files changed: +16 −9 lines

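The commit message draws an analogy to the shape attributes on tfp Distributions, which expose each shape both as a static value and as a `Tensor`. A rough sketch of that pairing (the entropy model `em` mentioned in the comments is hypothetical, not taken from this commit):

import tensorflow as tf
import tensorflow_probability as tfp

# tfp Distributions pair a static shape with a tensor-valued one:
prior = tfp.distributions.Normal(loc=tf.zeros((4, 8)), scale=1.)
print(prior.batch_shape)           # static TensorShape([4, 8])
print(prior.batch_shape_tensor())  # the same shape as an int32 Tensor

# After this commit, the entropy model base class mirrors that pairing:
#   em.prior_shape         -> static batch shape of `prior`
#   em.prior_shape_tensor  -> tf.constant(em.prior_shape, dtype=tf.int32)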

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 7 additions & 2 deletions
@@ -122,6 +122,11 @@ def prior_shape(self):
     """Batch shape of `prior` (dimensions which are not assumed i.i.d.)."""
     return self._prior_shape
 
+  @property
+  def prior_shape_tensor(self):
+    """Batch shape of `prior` as a `Tensor`."""
+    return tf.constant(self.prior_shape, dtype=tf.int32)
+
   @property
   def coding_rank(self):
     """Number of innermost dimensions considered a coding unit."""
@@ -217,10 +222,10 @@ def _build_tables(self, prior):
     pmf = tf.reshape(pmf, [max_length, -1])
     pmf = tf.transpose(pmf)
 
-    pmf_length = tf.broadcast_to(pmf_length, self.prior_shape)
+    pmf_length = tf.broadcast_to(pmf_length, self.prior_shape_tensor)
     pmf_length = tf.reshape(pmf_length, [-1])
     cdf_length = pmf_length + 2
-    cdf_offset = tf.broadcast_to(-minima, self.prior_shape)
+    cdf_offset = tf.broadcast_to(-minima, self.prior_shape_tensor)
     cdf_offset = tf.reshape(cdf_offset, [-1])
 
     # Prevent tensors from bouncing back and forth between host and GPU.
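A minimal standalone sketch (made-up shapes, not part of the commit) of what `_build_tables` relies on above: `tf.broadcast_to` and `tf.reshape` accept the batch shape as a plain int32 `Tensor`, so the new property converts the static shape once and downstream code can pass it around without further casting:

import tensorflow as tf

prior_shape = (4, 8)  # hypothetical batch shape of a prior
prior_shape_tensor = tf.constant(prior_shape, dtype=tf.int32)

pmf_length = tf.constant(13)  # scalar stand-in for a per-distribution PMF length
pmf_length = tf.broadcast_to(pmf_length, prior_shape_tensor)  # shape (4, 8)
pmf_length = tf.reshape(pmf_length, [-1])                     # shape (32,)
print(pmf_length.shape)  # (32,)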

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 7 additions & 6 deletions
@@ -124,7 +124,7 @@ def __init__(self, prior, coding_rank, compression=False,
       quantization_offset = None
     else:
       quantization_offset = tf.broadcast_to(
-          quantization_offset, self.prior_shape)
+          quantization_offset, self.prior_shape_tensor)
       quantization_offset = tf.Variable(
           quantization_offset, trainable=False, name="quantization_offset")
     self._quantization_offset = quantization_offset
@@ -133,9 +133,9 @@ def _compute_indexes(self, broadcast_shape):
     # TODO(jonycgn, ssjhv): Investigate broadcasting in range coding op.
     prior_size = functools.reduce(lambda x, y: x * y, self.prior_shape, 1)
     indexes = tf.range(prior_size, dtype=tf.int32)
-    indexes = tf.reshape(indexes, self.prior_shape)
+    indexes = tf.reshape(indexes, self.prior_shape_tensor)
     indexes = tf.broadcast_to(
-        indexes, tf.concat([broadcast_shape, self.prior_shape], 0))
+        indexes, tf.concat([broadcast_shape, self.prior_shape_tensor], 0))
     return indexes
 
   @tf.Module.with_name_scope
@@ -164,7 +164,8 @@ def bits(self, bottleneck, training=True):
     probs = self.prior.prob(quantized)
     probs = math_ops.lower_bound(probs, self.likelihood_bound)
     axes = tuple(range(-self.coding_rank, 0))
-    bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / -tf.math.log(2.)
+    bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / (
+        -tf.math.log(tf.constant(2., dtype=probs.dtype)))
     return bits
 
   @tf.Module.with_name_scope
@@ -265,7 +266,7 @@ def decompress(self, strings, broadcast_shape):
     broadcast_shape = tf.convert_to_tensor(broadcast_shape, dtype=tf.int32)
     batch_shape = tf.shape(strings)
     symbols_shape = tf.concat(
-        [batch_shape, broadcast_shape, self.prior_shape], 0)
+        [batch_shape, broadcast_shape, self.prior_shape_tensor], 0)
 
     indexes = self._compute_indexes(broadcast_shape)
     strings = tf.reshape(strings, [-1])
@@ -321,7 +322,7 @@ def from_config(cls, config):
     with self.name_scope:
       # pylint:disable=protected-access
       if config["quantization_offset"]:
-        zeros = tf.zeros(self.prior_shape, dtype=self.dtype)
+        zeros = tf.zeros(self.prior_shape_tensor, dtype=self.dtype)
         self._quantization_offset = tf.Variable(
             zeros, name="quantization_offset")
       else:
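The `bits()` change above is one of the type-casting cleanups the commit message mentions: `tf.math.log(2.)` is a float32 constant, and TensorFlow does not promote dtypes on division, so the old expression fails when the prior produces float64 likelihoods. A small sketch with made-up probabilities (not from the commit) of the before/after behavior:

import tensorflow as tf

probs = tf.constant([0.5, 0.25], dtype=tf.float64)
log_probs = tf.math.log(probs)

# bits = tf.reduce_sum(log_probs) / -tf.math.log(2.)  # float64 / float32: dtype mismatch error
bits = tf.reduce_sum(log_probs) / (
    -tf.math.log(tf.constant(2., dtype=probs.dtype)))  # both operands are float64
print(bits)  # tf.Tensor(3.0, ...) -- 1 bit for p=1/2 plus 2 bits for p=1/4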

tensorflow_compression/python/entropy_models/continuous_indexed.py

Lines changed: 2 additions & 1 deletion
@@ -282,7 +282,8 @@ def bits(self, bottleneck, indexes, training=True):
     probs = prior.prob(quantized)
     probs = math_ops.lower_bound(probs, self.likelihood_bound)
     axes = tuple(range(-self.coding_rank, 0))
-    bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / -tf.math.log(2.)
+    bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / (
+        -tf.math.log(tf.constant(2., dtype=probs.dtype)))
     return bits
 
   @tf.Module.with_name_scope
