Skip to content

Commit 2cb1a2b

Browse files
Johannes Ballé authored and copybara-github committed
Allow disabling sanity check.
This makes it possible to feed random bits into the decoder to treat it like a generator and produce samples from the learned data model. PiperOrigin-RevId: 438913987 Change-Id: Ia46879dff277b82cafaa806555658edf89a4fe84
1 parent 8f4cc21 commit 2cb1a2b

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(self,
122122
cdf_shapes=None,
123123
offset_heuristic=True,
124124
quantization_offset=None,
125+
decode_sanity_check=True,
125126
laplace_tail_mass=0):
126127
"""Initializes the instance.
127128
@@ -168,6 +169,8 @@ def __init__(self,
168169
if you are using soft quantization during training.
169170
quantization_offset: `tf.Tensor` or `None`. The quantization offsets to
170171
use. If provided (not `None`), then `offset_heuristic` is ineffective.
172+
decode_sanity_check: Boolean. If `True`, raises an error if the binary
173+
strings passed into `decompress` are not completely decoded.
171174
laplace_tail_mass: Float. If positive, will augment the prior with a
172175
Laplace mixture for training stability. (experimental)
173176
"""
@@ -197,6 +200,7 @@ def __init__(self,
197200
prior_shape if prior is None else prior.batch_shape)
198201
if self.coding_rank < self.prior_shape.rank:
199202
raise ValueError("`coding_rank` can't be smaller than `prior_shape`.")
203+
self.decode_sanity_check = decode_sanity_check
200204

201205
with self.name_scope:
202206
if cdf_shapes is not None:
@@ -402,7 +406,8 @@ def decompress(self, strings, broadcast_shape):
402406
handle, symbols = gen_ops.entropy_decode_channel(
403407
handle, decode_shape, self.cdf_offset.dtype)
404408
sanity = gen_ops.entropy_decode_finalize(handle)
405-
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
409+
if self.decode_sanity_check:
410+
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
406411
symbols += self.cdf_offset
407412
symbols = tf.reshape(symbols, output_shape)
408413
outputs = tf.cast(symbols, self.bottleneck_dtype)

tensorflow_compression/python/entropy_models/continuous_indexed.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(self,
138138
range_coder_precision=12,
139139
bottleneck_dtype=None,
140140
prior_dtype=tf.float32,
141+
decode_sanity_check=True,
141142
laplace_tail_mass=0):
142143
"""Initializes the instance.
143144
@@ -186,6 +187,8 @@ def __init__(self,
186187
Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
187188
prior_dtype: `tf.dtypes.DType`. Data type of prior and probability
188189
computations. Defaults to `tf.float32`.
190+
decode_sanity_check: Boolean. If `True`, raises an error if the binary
191+
strings passed into `decompress` are not completely decoded.
189192
laplace_tail_mass: Float. If positive, will augment the prior with a
190193
laplace mixture for training stability. (experimental)
191194
"""
@@ -216,6 +219,7 @@ def __init__(self,
216219
self._prior_fn = prior_fn
217220
self._parameter_fns = dict(parameter_fns)
218221
self._prior_dtype = tf.as_dtype(prior_dtype)
222+
self.decode_sanity_check = decode_sanity_check
219223

220224
with self.name_scope:
221225
if self.compression:
@@ -404,7 +408,8 @@ def decompress(self, strings, indexes):
404408
handle, symbols = gen_ops.entropy_decode_index(
405409
handle, flat_indexes, decode_shape, self.cdf_offset.dtype)
406410
sanity = gen_ops.entropy_decode_finalize(handle)
407-
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
411+
if self.decode_sanity_check:
412+
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
408413
symbols += tf.gather(self.cdf_offset, flat_indexes)
409414
return tf.cast(symbols, self.bottleneck_dtype)
410415

tensorflow_compression/python/entropy_models/universal.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def __init__(self,
8282
range_coder_precision=12,
8383
bottleneck_dtype=None,
8484
num_noise_levels=15,
85-
stateless=False):
85+
stateless=False,
86+
decode_sanity_check=True):
8687
"""Initializes the instance.
8788
8889
Args:
@@ -118,6 +119,8 @@ def __init__(self,
118119
allows it to be constructed within a `tf.function` body. If
119120
`compression=False`, then `stateless=True` is implied and the provided
120121
value is ignored.
122+
decode_sanity_check: Boolean. If `True`, raises an error if the binary
123+
strings passed into `decompress` are not completely decoded.
121124
"""
122125
if prior.event_shape.rank:
123126
raise ValueError("`prior` must be a (batch of) scalar distribution(s).")
@@ -135,6 +138,7 @@ def __init__(self,
135138
self._num_noise_levels = num_noise_levels
136139
if self.coding_rank < self.prior_shape.rank:
137140
raise ValueError("`coding_rank` can't be smaller than `prior_shape`.")
141+
self.decode_sanity_check = decode_sanity_check
138142

139143
with self.name_scope:
140144
if self.compression:
@@ -285,7 +289,8 @@ def decompress(self, strings, broadcast_shape):
285289
handle, symbols = gen_ops.entropy_decode_index(
286290
handle, decode_indexes, decode_shape, self.cdf_offset.dtype)
287291
sanity = gen_ops.entropy_decode_finalize(handle)
288-
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
292+
if self.decode_sanity_check:
293+
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
289294
symbols += tf.gather(self.cdf_offset, indexes)
290295
outputs = tf.cast(symbols, self.bottleneck_dtype)
291296
return outputs + offset
@@ -321,7 +326,8 @@ def __init__(self,
321326
bottleneck_dtype=None,
322327
prior_dtype=tf.float32,
323328
stateless=False,
324-
num_noise_levels=15):
329+
num_noise_levels=15,
330+
decode_sanity_check=True):
325331
"""Initializes the instance.
326332
327333
Args:
@@ -364,6 +370,8 @@ def __init__(self,
364370
rather than `Variable`s.
365371
num_noise_levels: Integer. The number of levels used to quantize the
366372
uniform noise.
373+
decode_sanity_check: Boolean. If `True`, raises an error if the binary
374+
strings passed into `decompress` are not completely decoded.
367375
"""
368376
if coding_rank <= 0:
369377
raise ValueError("`coding_rank` must be larger than 0.")
@@ -393,6 +401,7 @@ def __init__(self,
393401
self._parameter_fns = dict(parameter_fns)
394402
self._prior_dtype = tf.as_dtype(prior_dtype)
395403
self._num_noise_levels = num_noise_levels
404+
self.decode_sanity_check = decode_sanity_check
396405

397406
with self.name_scope:
398407
if self.compression:
@@ -577,7 +586,8 @@ def decompress(self, strings, indexes):
577586
handle, symbols = gen_ops.entropy_decode_index(
578587
handle, flat_indexes, decode_shape, self.cdf_offset.dtype)
579588
sanity = gen_ops.entropy_decode_finalize(handle)
580-
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
589+
if self.decode_sanity_check:
590+
tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
581591
symbols += tf.gather(self.cdf_offset, flat_indexes)
582592
offset = self._offset_from_indexes(indexes)
583593
return tf.cast(symbols, self.bottleneck_dtype) + offset

0 commit comments

Comments
 (0)