Skip to content

Commit 7666abc

Browse files
ssjhv authored and copybara-github committed
Added tf.control_dependencies() context managers around tf.debugging.assert* ops.
This is for better compatibility with TF1. However, there may be some side effects not executed in TF1 still. PiperOrigin-RevId: 485946240 Change-Id: Icb3a7f4cb5c9edc5de18b0751837f000c1f51287
1 parent c4c8622 commit 7666abc

File tree

4 files changed

+22
-15
lines changed

4 files changed

+22
-15
lines changed

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -301,17 +301,16 @@ def _log_prob(self, prior, bottleneck_perturbed):
301301
laplace_tail_mass = self.laplace_tail_mass
302302

303303
def mixture_log_prob_fn():
304-
tf.debugging.assert_less(
305-
laplace_tail_mass,
306-
tf.constant(1.0, prior.dtype),
307-
message="`laplace_tail_mass` must be less than 1.")
308-
laplace_prior = uniform_noise.NoisyLaplace(
309-
loc=tf.constant(0, dtype=prior.dtype),
310-
scale=tf.constant(1, dtype=prior.dtype))
311-
probs = prior.prob(bottleneck_perturbed)
312-
probs = ((1 - laplace_tail_mass) * probs +
313-
laplace_tail_mass *
314-
laplace_prior.prob(bottleneck_perturbed))
304+
with tf.control_dependencies([tf.debugging.assert_less(
305+
laplace_tail_mass, tf.constant(1.0, prior.dtype),
306+
message="`laplace_tail_mass` must be less than 1.")]):
307+
laplace_prior = uniform_noise.NoisyLaplace(
308+
loc=tf.constant(0, dtype=prior.dtype),
309+
scale=tf.constant(1, dtype=prior.dtype))
310+
probs = prior.prob(bottleneck_perturbed)
311+
probs = ((1 - laplace_tail_mass) * probs +
312+
laplace_tail_mass *
313+
laplace_prior.prob(bottleneck_perturbed))
315314
probs_too_small = probs < 1e-10
316315
probs_bounded = tf.maximum(probs, 1e-10)
317316
return tf.where(

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -410,7 +410,9 @@ def decompress(self, strings, broadcast_shape):
410410
handle, decode_shape, self.cdf_offset.dtype)
411411
sanity = gen_ops.entropy_decode_finalize(handle)
412412
if self.decode_sanity_check:
413-
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
413+
with tf.control_dependencies([tf.debugging.assert_equal(
414+
sanity, True, message="Sanity check failed.")]):
415+
symbols = tf.identity(symbols)
414416
symbols += self.cdf_offset
415417
symbols = tf.reshape(symbols, output_shape)
416418
outputs = tf.cast(symbols, self.bottleneck_dtype)

tensorflow_compression/python/entropy_models/continuous_indexed.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -410,7 +410,9 @@ def decompress(self, strings, indexes):
410410
handle, flat_indexes, decode_shape, self.cdf_offset.dtype)
411411
sanity = gen_ops.entropy_decode_finalize(handle)
412412
if self.decode_sanity_check:
413-
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
413+
with tf.control_dependencies([tf.debugging.assert_equal(
414+
sanity, True, message="Sanity check failed.")]):
415+
symbols = tf.identity(symbols)
414416
symbols += tf.gather(self.cdf_offset, flat_indexes)
415417
return tf.cast(symbols, self.bottleneck_dtype)
416418

tensorflow_compression/python/entropy_models/universal.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -292,7 +292,9 @@ def decompress(self, strings, broadcast_shape):
292292
handle, decode_indexes, decode_shape, self.cdf_offset.dtype)
293293
sanity = gen_ops.entropy_decode_finalize(handle)
294294
if self.decode_sanity_check:
295-
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
295+
with tf.control_dependencies([tf.debugging.assert_equal(
296+
sanity, True, message="Sanity check failed.")]):
297+
symbols = tf.identity(symbols)
296298
symbols += tf.gather(self.cdf_offset, indexes)
297299
outputs = tf.cast(symbols, self.bottleneck_dtype)
298300
return outputs + offset
@@ -589,7 +591,9 @@ def decompress(self, strings, indexes):
589591
handle, flat_indexes, decode_shape, self.cdf_offset.dtype)
590592
sanity = gen_ops.entropy_decode_finalize(handle)
591593
if self.decode_sanity_check:
592-
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
594+
with tf.control_dependencies([tf.debugging.assert_equal(
595+
sanity, True, message="Sanity check failed.")]):
596+
symbols = tf.identity(symbols)
593597
symbols += tf.gather(self.cdf_offset, flat_indexes)
594598
offset = self._offset_from_indexes(indexes)
595599
return tf.cast(symbols, self.bottleneck_dtype) + offset

0 commit comments

Comments (0)