@@ -82,7 +82,8 @@ def __init__(self,
82
82
range_coder_precision = 12 ,
83
83
bottleneck_dtype = None ,
84
84
num_noise_levels = 15 ,
85
- stateless = False ):
85
+ stateless = False ,
86
+ decode_sanity_check = True ):
86
87
"""Initializes the instance.
87
88
88
89
Args:
@@ -118,6 +119,8 @@ def __init__(self,
118
119
allows it to be constructed within a `tf.function` body. If
119
120
`compression=False`, then `stateless=True` is implied and the provided
120
121
value is ignored.
122
+ decode_sanity_check: Boolean. If `True`, an raises an error if the binary
123
+ strings passed into `decompress` are not completely decoded.
121
124
"""
122
125
if prior .event_shape .rank :
123
126
raise ValueError ("`prior` must be a (batch of) scalar distribution(s)." )
@@ -135,6 +138,7 @@ def __init__(self,
135
138
self ._num_noise_levels = num_noise_levels
136
139
if self .coding_rank < self .prior_shape .rank :
137
140
raise ValueError ("`coding_rank` can't be smaller than `prior_shape`." )
141
+ self .decode_sanity_check = decode_sanity_check
138
142
139
143
with self .name_scope :
140
144
if self .compression :
@@ -285,7 +289,8 @@ def decompress(self, strings, broadcast_shape):
285
289
handle , symbols = gen_ops .entropy_decode_index (
286
290
handle , decode_indexes , decode_shape , self .cdf_offset .dtype )
287
291
sanity = gen_ops .entropy_decode_finalize (handle )
288
- tf .debugging .assert_equal (sanity , True , message = "Sanity check failed." )
292
+ if self .decode_sanity_check :
293
+ tf .debugging .assert_equal (sanity , True , message = "Sanity check failed." )
289
294
symbols += tf .gather (self .cdf_offset , indexes )
290
295
outputs = tf .cast (symbols , self .bottleneck_dtype )
291
296
return outputs + offset
@@ -321,7 +326,8 @@ def __init__(self,
321
326
bottleneck_dtype = None ,
322
327
prior_dtype = tf .float32 ,
323
328
stateless = False ,
324
- num_noise_levels = 15 ):
329
+ num_noise_levels = 15 ,
330
+ decode_sanity_check = True ):
325
331
"""Initializes the instance.
326
332
327
333
Args:
@@ -364,6 +370,8 @@ def __init__(self,
364
370
rather than `Variable`s.
365
371
num_noise_levels: Integer. The number of levels used to quantize the
366
372
uniform noise.
373
+ decode_sanity_check: Boolean. If `True`, an raises an error if the binary
374
+ strings passed into `decompress` are not completely decoded.
367
375
"""
368
376
if coding_rank <= 0 :
369
377
raise ValueError ("`coding_rank` must be larger than 0." )
@@ -393,6 +401,7 @@ def __init__(self,
393
401
self ._parameter_fns = dict (parameter_fns )
394
402
self ._prior_dtype = tf .as_dtype (prior_dtype )
395
403
self ._num_noise_levels = num_noise_levels
404
+ self .decode_sanity_check = decode_sanity_check
396
405
397
406
with self .name_scope :
398
407
if self .compression :
@@ -577,7 +586,8 @@ def decompress(self, strings, indexes):
577
586
handle , symbols = gen_ops .entropy_decode_index (
578
587
handle , flat_indexes , decode_shape , self .cdf_offset .dtype )
579
588
sanity = gen_ops .entropy_decode_finalize (handle )
580
- tf .debugging .assert_equal (sanity , True , message = "Sanity check failed." )
589
+ if self .decode_sanity_check :
590
+ tf .debugging .assert_equal (sanity , True , message = "Sanity check failed." )
581
591
symbols += tf .gather (self .cdf_offset , flat_indexes )
582
592
offset = self ._offset_from_indexes (indexes )
583
593
return tf .cast (symbols , self .bottleneck_dtype ) + offset
0 commit comments