
Commit 2a8ede2

Johannes Ballé authored and copybara-github committed
Small edits to distributions/entropy models.
- Lower bounds the probabilities computed in DeepFactorized by zero.
- Adds a name scope to the variables created in DeepFactorized.
- Enables deleting the prior distribution from the entropy model.
- Removes deprecated backprop argument from tf.while.

PiperOrigin-RevId: 323400592
Change-Id: Iff3cb1cb3f83fb7381798e6f71779126ac7a00da
1 parent cb1dee3 commit 2a8ede2

4 files changed: +14 −5 lines changed

tensorflow_compression/python/distributions/BUILD

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,10 @@ py_library(
     name = "deep_factorized",
     srcs = ["deep_factorized.py"],
     srcs_version = "PY3",
-    deps = [":helpers"],
+    deps = [
+        ":helpers",
+        "//tensorflow_compression/python/ops:math_ops",
+    ],
 )
 
 py_test(

tensorflow_compression/python/distributions/deep_factorized.py

Lines changed: 4 additions & 1 deletion
@@ -19,6 +19,7 @@
 import tensorflow_probability as tfp
 
 from tensorflow_compression.python.distributions import helpers
+from tensorflow_compression.python.ops import math_ops
 
 
 __all__ = ["DeepFactorized"]
@@ -77,7 +78,8 @@ def __init__(self, batch_shape=(), num_filters=(3, 3), init_scale=10,
         parameters=parameters,
         name=name,
     )
-    self._make_variables()
+    with self.name_scope:
+      self._make_variables()
 
   @property
   def num_filters(self):
@@ -168,6 +170,7 @@ def _prob(self, y):
     # Flip signs if we can move more towards the left tail of the sigmoid.
     sign = tf.stop_gradient(-tf.math.sign(lower + upper))
     p = abs(tf.sigmoid(sign * upper) - tf.sigmoid(sign * lower))
+    p = math_ops.lower_bound(p, 0.)
 
     # Convert back to (broadcasted) input tensor shape.
     p = tf.transpose(p, (2, 1, 0))
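The new `math_ops.lower_bound(p, 0.)` call clamps the computed probability at zero. Below is a minimal sketch of what such a lower-bound op could look like, assuming "identity-if-towards" gradient semantics (forward pass identical to tf.maximum, gradients passed through when they would push the value back above the bound); it is not the library's actual implementation.

# Minimal sketch of a lower-bound op, assuming "identity-if-towards" gradient
# semantics; not the actual math_ops.lower_bound implementation.
import tensorflow as tf

@tf.custom_gradient
def lower_bound_at_zero(inputs):
  outputs = tf.maximum(inputs, 0.)
  def grad(d_outputs):
    # Let gradients through where the input already satisfies the bound, or
    # where a descent step (input -= lr * grad) would move it back up.
    pass_through = tf.logical_or(inputs >= 0., d_outputs < 0.)
    return tf.where(pass_through, d_outputs, tf.zeros_like(d_outputs))
  return outputs, grad

p = tf.constant([-1e-8, 0.3])
print(lower_bound_at_zero(p).numpy())  # [0.  0.3]

The other change wraps variable creation in `with self.name_scope:`. For tf.Module subclasses (which TFP distributions build on), entering the module's name scope prefixes the created variables' names with the module name. A small illustration with a generic tf.Module, not DeepFactorized itself:

# Illustration of the name-scope effect on variable names.
import tensorflow as tf

class Scoped(tf.Module):
  def __init__(self, name="scoped"):
    super().__init__(name=name)
    with self.name_scope:
      self.matrix = tf.Variable(tf.zeros((2, 2)), name="matrix")

print(Scoped().matrix.name)  # "scoped/matrix:0" rather than "matrix:0"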

tensorflow_compression/python/distributions/helpers.py

Lines changed: 1 addition & 2 deletions
@@ -82,8 +82,7 @@ def loop_body(tails, m, v, count):
   init_v = tf.ones(shape, dtype=dtype)
   init_count = tf.zeros(shape, dtype=tf.int32)
   return tf.while_loop(
-      loop_cond, loop_body, (init_tails, init_m, init_v, init_count),
-      back_prop=False)[0]
+      loop_cond, loop_body, (init_tails, init_m, init_v, init_count))[0]
 
 
 def quantization_offset(distribution):
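The `back_prop` argument of tf.while_loop is deprecated in TF 2, so this change simply drops it. Where the no-gradient behavior is still wanted, TensorFlow's deprecation notice suggests wrapping the loop result in tf.stop_gradient instead. A small sketch with a generic loop, not the tail-estimation loop above:

# Sketch of the recommended replacement for back_prop=False: wrap the loop
# output in tf.stop_gradient instead of passing the deprecated argument.
import tensorflow as tf

x = tf.Variable(2.)
with tf.GradientTape() as tape:
  # Old form: tf.while_loop(cond, body, (x,), back_prop=False)[0]
  result = tf.stop_gradient(
      tf.while_loop(lambda v: v < 100., lambda v: (v * v,), (x,))[0])
print(tape.gradient(result, x))  # None: no gradient flows through the loop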

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 5 additions & 1 deletion
@@ -88,9 +88,13 @@ def prior(self):
       raise RuntimeError(
           "This entropy model doesn't hold a reference to its prior "
           "distribution. This can happen when it is unserialized, because "
-          "the prior is generally not serializable.")
+          "the prior is not generally serializable.")
     return self._prior
 
+  @prior.deleter
+  def prior(self):
+    del self._prior
+
   def _check_compression(self):
     if not self.compression:
       raise RuntimeError(
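With the new `@prior.deleter`, the prior can be explicitly removed from an entropy model, putting it in the same state as an unserialized model. A hedged usage sketch; the NoisyNormal and ContinuousBatchedEntropyModel constructor arguments are assumptions about the tfc API, not taken from this commit:

# Hedged usage sketch: drop the (often non-serializable) prior once it is no
# longer needed. Constructor arguments are assumptions, not from this commit.
import tensorflow_compression as tfc

prior = tfc.NoisyNormal(loc=0., scale=1.)
em = tfc.ContinuousBatchedEntropyModel(prior, coding_rank=1, compression=True)

del em.prior   # now supported by the @prior.deleter added above
# em.prior     # accessing it afterwards raises the RuntimeError from the getter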
