Skip to content

Commit 4bdeb05

Browse files
brianwa84tensorflower-gardener
authored andcommitted
Update distribution_bijectors_test for non-list MCMC state.
PiperOrigin-RevId: 374723488
1 parent e838d8b commit 4bdeb05

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,12 @@ def model():
144144
'build_asvi_surrogate_posterior')
145145
def test_mcmc_funnel_docstring_example_runs(self):
146146

147-
# TODO(b/170865194): Use JDC here once sample_chain can take non-list state.
148-
model_with_funnel = tfd.JointDistributionSequentialAutoBatched([
149-
tfd.Normal(loc=-1., scale=2., name='z'),
150-
lambda z: tfd.Normal(loc=[0.], scale=tf.exp(z), name='x'),
151-
lambda x: tfd.Poisson(log_rate=x, name='y')])
152-
pinned_model = tfp.experimental.distributions.JointDistributionPinned(
153-
model_with_funnel, y=[1])
147+
@tfd.JointDistributionCoroutineAutoBatched
148+
def model_with_funnel():
149+
z = yield tfd.Normal(loc=-1., scale=2., name='z')
150+
x = yield tfd.Normal(loc=[0.], scale=tf.exp(z), name='x')
151+
yield tfd.Poisson(log_rate=x, name='y')
152+
pinned_model = model_with_funnel.experimental_pin(y=[1])
154153
surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
155154
pinned_model)
156155

0 commit comments

Comments
 (0)