Skip to content

Commit e4263cd

Browse files
Johannes Ballécopybara-github
authored andcommitted
Explicitly fall back on floatx() in case compute_dtype is None.
PiperOrigin-RevId: 427458260 Change-Id: I997f6bf9a3c4f466e3a952b0c29320c15a918abb
1 parent 97ab580 commit e4263cd

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def __init__(self,
8080
self._tail_mass = float(tail_mass)
8181
if bottleneck_dtype is None:
8282
bottleneck_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
83+
if bottleneck_dtype is None:
84+
bottleneck_dtype = tf.keras.backend.floatx()
8385
self._bottleneck_dtype = tf.as_dtype(bottleneck_dtype)
8486
self._laplace_tail_mass = float(laplace_tail_mass)
8587

0 commit comments

Comments
 (0)