Skip to content

Commit f572d16

Browse files
davmretensorflower-gardener
authored andcommitted
Avoid sampling from JDSequential/JDNamed models at init time.
Now that we use static tracing for most JD properties, it makes sense to defer the cost of building and caching a list of distributions until some method actually needs them. PiperOrigin-RevId: 385001342
1 parent 9afe430 commit f572d16

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

tensorflow_probability/python/distributions/joint_distribution_sequential.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,12 @@ def __init__(self, model, validate_args=False, name=None):
210210
self._build(model)
211211

212212
self._single_sample_distributions = {}
213-
self._get_single_sample_distributions(candidate_dists=[
214-
None if a else d() for d, a
215-
in zip(self._dist_fn_wrapped, self._dist_fn_args)])
213+
214+
# If the model consists entirely of prebuilt distributions with no
215+
# dependencies, cache them directly to avoid a sample call down the road.
216+
if not any(self._dist_fn_args):
217+
self._get_single_sample_distributions(
218+
candidate_dists=[d() for d in self._dist_fn_wrapped])
216219

217220
super(JointDistributionSequential, self).__init__(
218221
dtype=None, # Ignored; we'll override.

tensorflow_probability/python/distributions/joint_distribution_sequential_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ def test_norequired_args_maker(self):
176176
"""Test that only non-default args are passed through."""
177177
with self.assertRaisesWithPredicateMatch(
178178
ValueError, 'Must pass probs or logits, but not both.'):
179-
tfd.JointDistributionSequential([tfd.Normal(0., 1.), tfd.Bernoulli])
179+
d = tfd.JointDistributionSequential([tfd.Normal(0., 1.), tfd.Bernoulli])
180+
d.sample(seed=test_util.test_seed())
180181

181182
def test_graph_resolution(self):
182183
d = tfd.JointDistributionSequential(
@@ -767,6 +768,28 @@ def test_creates_valid_coroutine(self):
767768
jdc.sample([5], seed=test_util.test_seed()))]
768769
self.assertAllEqualNested(sample_shapes, jdc_sample_shapes)
769770

771+
def test_init_does_not_execute_model(self):
772+
model_traces = []
773+
def record_model_called(x):
774+
model_traces.append(x)
775+
return x
776+
777+
model = tfd.JointDistributionSequential(
778+
[
779+
tfd.Normal(0., 1.),
780+
lambda z: tfd.Normal(record_model_called(z), 1.)
781+
],
782+
validate_args=True)
783+
# Model should not be called from init.
784+
self.assertLen(model_traces, 0)
785+
model.sample(seed=test_util.test_seed())
786+
# The first sample call will run the model twice (for shape
787+
# inference + actually sampling).
788+
self.assertLen(model_traces, 2)
789+
# Subsequent calls should only run the model once.
790+
model.sample([2], seed=test_util.test_seed())
791+
self.assertLen(model_traces, 3)
792+
770793

771794
class ResolveDistributionNamesTest(test_util.TestCase):
772795

0 commit comments

Comments
 (0)