Skip to content

Commit d5a7005

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Don't require Root in JDC when sampling with a trivial sample shape.
Pros: In JAX world we often wrap large chunks of computation inside a vmap, which can encompass the log_prob/sample calls of a JDC. In this setting a JDC will never see a non-trivial sample shape. Root was introduced to handle non-trivial sample shapes, so in this setting, this annotation is superfluous. Thus, this change removes the requirement for Root when only trivial sample shapes are considered, improving the UX of JDCs in JAX. Cons: The check is deferred to non-trivial sample shapes, meaning that a TF user might construct a malformed distribution and not receive an error until they try to sample with a non-trivial sample shape. Prior to this change, they would get an error as soon as `_get_single_sample_distributions` was called, which happens with most property accesses. Backwards compatibility: Existing well-formed JDCs will continue to work, and will continue being checked for correctness modulo the 'con' above. PiperOrigin-RevId: 385498069
1 parent 1f7bc95 commit d5a7005

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

tensorflow_probability/python/distributions/joint_distribution.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -787,9 +787,12 @@ def _execute_model(self,
787787
gen = self._model_coroutine()
788788
index = 0
789789
d = next(gen)
790-
if self._require_root and not isinstance(d, self.Root):
791-
raise ValueError('First distribution yielded by coroutine must '
792-
'be wrapped in `Root`.')
790+
if self._require_root:
791+
if distribution_util.shape_may_be_nontrivial(
792+
sample_shape) and not isinstance(d, self.Root):
793+
raise ValueError('First distribution yielded by coroutine must '
794+
'be wrapped in `Root` when requesting a nontrivial '
795+
f'sample_shape = {sample_shape}.')
793796
try:
794797
while True:
795798
actual_distribution = d.distribution if isinstance(d, self.Root) else d

tensorflow_probability/python/distributions/joint_distribution_coroutine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ class JointDistributionCoroutine(joint_distribution_lib.JointDistribution):
7878
The distributions that have been wrapped in the
7979
`JointDistributionCoroutine.Root` class will be called with `sample_shape` as
8080
the `sample_shape` argument, and the unwrapped distributions
81-
will be called with `()` as the `sample_shape` argument.
81+
will be called with `()` as the `sample_shape` argument. The `Root` annotation
82+
can be omitted if you never intend to use a `sample_shape` other than `()`.
8283
8384
It is the user's responsibility to ensure that
8485
each of the distributions generates samples with the specified sample

tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,11 @@ def singleton_jdn_model_fn():
8686
@test_util.test_all_tf_execution_regimes
8787
class JointDistributionCoroutineTest(test_util.TestCase):
8888

89-
def test_batch_and_event_shape_no_plate(self):
89+
@parameterized.named_parameters(
90+
('_with_root', True),
91+
('_without_root', False),
92+
)
93+
def test_batch_and_event_shape_no_plate(self, use_root):
9094
# The joint distribution specified below corresponds to this
9195
# graphical model
9296
#
@@ -95,8 +99,10 @@ def test_batch_and_event_shape_no_plate(self):
9599
# \ v
96100
# `->-(c)
97101

102+
root = Root if use_root else lambda x: x
103+
98104
def dist():
99-
a = yield Root(tfd.Bernoulli(probs=0.5,
105+
a = yield root(tfd.Bernoulli(probs=0.5,
100106
dtype=tf.float32))
101107
b = yield tfd.Bernoulli(probs=0.25 + 0.5*a,
102108
dtype=tf.float32)
@@ -135,7 +141,11 @@ def dist():
135141
self.assertAllEqual(batch_shape[1], [])
136142
self.assertAllEqual(batch_shape[2], [])
137143

138-
def test_batch_and_event_shape_with_plate(self):
144+
@parameterized.named_parameters(
145+
('_with_root', True),
146+
('_without_root', False),
147+
)
148+
def test_batch_and_event_shape_with_plate(self, use_root):
139149
# The joint distribution specified below corresponds to this
140150
# graphical model
141151
#
@@ -146,9 +156,11 @@ def test_batch_and_event_shape_with_plate(self):
146156
# (df)--+-->---(x) |
147157
# +--------20-+
148158

159+
root = Root if use_root else lambda x: x
160+
149161
def dist():
150-
g = yield Root(tfd.Gamma(2, 2))
151-
df = yield Root(tfd.Exponential(1.))
162+
g = yield root(tfd.Gamma(2, 2))
163+
df = yield root(tfd.Exponential(1.))
152164
loc = yield tfd.Sample(tfd.Normal(0, g), 20)
153165
yield tfd.Independent(tfd.StudentT(tf.expand_dims(df, -1), loc, 1), 1)
154166

@@ -395,7 +407,7 @@ def dist():
395407
with self.assertRaisesRegexp(
396408
Exception,
397409
'must be wrapped in `Root`'):
398-
self.evaluate(joint.sample(seed=test_util.test_seed()))
410+
self.evaluate(joint.sample(2, seed=test_util.test_seed()))
399411

400412
@parameterized.named_parameters(
401413
('basic', basic_model_with_names_fn),

0 commit comments

Comments
 (0)