Skip to content

Commit 956d09f

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
Fix bug in windowed_mcmc related to constrained distributions having different event_shapes from unconstrained.
PiperOrigin-RevId: 390202596
1 parent b71b6af commit 956d09f

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

tensorflow_probability/python/experimental/mcmc/windowed_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def _get_flat_unconstraining_bijector(jd_model):
168168
event_space_bij = jd_model.experimental_default_event_space_bijector()
169169
flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())
170170

171-
unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
172-
jd_model.event_shape_tensor())
171+
unconstrained_shapes = event_space_bij(
172+
flat_bijector).inverse_event_shape_tensor(jd_model.event_shape_tensor())
173173

174174
# this reshaping is required as as split can produce a tensor of shape [1]
175175
# when the distribution event shape is []

tensorflow_probability/python/experimental/mcmc/windowed_sampling_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,12 @@ def mk_y(x):
459459
self.assertEqual((2, 64, 10, 3), states['x'].shape)
460460
self.assertEqual((2, 10, 1), trace['step_size'].shape)
461461

462+
def test_bijector(self):
463+
dist = tfd.JointDistributionSequential([tfd.Dirichlet(tf.ones(2))])
464+
bij, _ = windowed_sampling._get_flat_unconstraining_bijector(dist)
465+
draw = dist.sample(seed=test_util.test_seed())
466+
self.assertAllCloseNested(bij.inverse(bij(draw)), draw)
467+
462468

463469
@test_util.test_graph_and_eager_modes
464470
class WindowedSamplingStepSizeTest(test_util.TestCase):

0 commit comments

Comments
 (0)