Skip to content

Commit 50a1ca9

Browse files
brianwa84jburnim
authored andcommitted
Add another structured state test, then make it pass with a small change to PHMC.
PiperOrigin-RevId: 346322156
1 parent 9bebb90 commit 50a1ca9

File tree

3 files changed

+60
-16
lines changed

3 files changed

+60
-16
lines changed

tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
280280
state_gradients_are_stopped=self.state_gradients_are_stopped)
281281

282282
seed = samplers.sanitize_seed(seed)
283-
current_momentum_parts = momentum_distribution.sample(seed=seed)
283+
current_momentum_parts = list(momentum_distribution.sample(seed=seed))
284284
momentum_log_prob = getattr(momentum_distribution,
285285
'_log_prob_unnormalized',
286286
momentum_distribution.log_prob)

tensorflow_probability/python/mcmc/internal/util.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def make_name(super_name, default_super_name, sub_name):
144144
def _choose_base_case(is_accepted,
145145
proposed,
146146
current,
147-
name=None):
147+
name=None,
148+
addr=None,):
148149
"""Helper to `choose` which expand_dims `is_accepted` and applies tf.where."""
149150
def _where(proposed, current):
150151
"""Wraps `tf.where`."""
@@ -162,30 +163,38 @@ def _where(proposed, current):
162163
with tf.name_scope(name or 'choose'):
163164
if not is_list_like(proposed):
164165
return _where(proposed, current)
165-
return [(choose(is_accepted, p, c, name=name) if is_namedtuple_like(p)
166-
else _where(p, c))
167-
for p, c in zip(proposed, current)]
166+
return tf.nest.pack_sequence_as(
167+
current,
168+
[(_choose_recursive(is_accepted, p, c, name=name, addr=f'{addr}[i]')
169+
if is_namedtuple_like(p) else
170+
_where(p, c)) for i, (p, c) in enumerate(zip(proposed, current))])
168171

169172

170-
def choose(is_accepted, proposed, current, name=None):
171-
"""Helper which expand_dims `is_accepted` then applies tf.where."""
173+
def _choose_recursive(is_accepted, proposed, current, name=None, addr='<root>'):
174+
"""Recursion helper which also reports the address of any failures."""
172175
with tf.name_scope(name or 'choose'):
173176
if not is_namedtuple_like(proposed):
174-
return _choose_base_case(is_accepted, proposed, current, name=name)
177+
return _choose_base_case(is_accepted, proposed, current, name=name,
178+
addr=addr)
175179
if not isinstance(proposed, type(current)):
176-
raise TypeError('Type of `proposed` ({}) must be identical to '
177-
'type of `current` ({})'.format(
178-
type(proposed).__name__,
179-
type(current).__name__))
180+
raise TypeError(
181+
f'Type of `proposed` ({type(proposed).__name__}) must be identical '
182+
f'to type of `current` ({type(current).__name__}). (At "{addr}".)')
180183
items = {}
181184
for fn in proposed._fields:
182-
items[fn] = choose(is_accepted,
183-
getattr(proposed, fn),
184-
getattr(current, fn),
185-
name=name)
185+
items[fn] = _choose_recursive(is_accepted,
186+
getattr(proposed, fn),
187+
getattr(current, fn),
188+
name=name,
189+
addr=f'{addr}/{fn}')
186190
return type(proposed)(**items)
187191

188192

193+
def choose(is_accepted, proposed, current, name=None):
194+
"""Helper which expand_dims `is_accepted` then applies tf.where."""
195+
return _choose_recursive(is_accepted, proposed, current, name=name)
196+
197+
189198
def strip_seeds(obj):
190199
if not is_namedtuple_like(obj):
191200
return obj

tensorflow_probability/python/mcmc/sample_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tensorflow_probability.python.internal import tensorshape_util
3232
from tensorflow_probability.python.internal import test_util
3333

34+
tfb = tfp.bijectors
3435
tfd = tfp.distributions
3536

3637

@@ -404,6 +405,40 @@ def sample():
404405
seed=seed_stream())
405406
self.evaluate(sample())
406407

408+
@test_util.jax_disable_test_missing_functionality('PHMC b/175107050')
409+
@test_util.numpy_disable_gradient_test('HMC')
410+
def testStructuredState2(self):
411+
@tfd.JointDistributionCoroutineAutoBatched
412+
def model():
413+
mu = yield tfd.Sample(tfd.Normal(0, 1), [65], name='mu')
414+
sigma = yield tfd.Sample(tfd.Exponential(1.), [65], name='sigma')
415+
beta = yield tfd.Sample(
416+
tfd.Normal(loc=tf.gather(mu, tf.range(436) % 65, axis=-1),
417+
scale=tf.gather(sigma, tf.range(436) % 65, axis=-1)),
418+
4, name='beta')
419+
_ = yield tfd.Multinomial(total_count=100.,
420+
logits=tfb.Pad([[0, 1]])(beta),
421+
name='y')
422+
423+
stream = test_util.test_seed_stream()
424+
pinned = model.experimental_pin(y=model.sample(seed=stream()).y)
425+
struct = pinned.dtype
426+
stddevs = struct._make([
427+
tf.fill([65], .1), tf.fill([65], 1.), tf.fill([436, 4], 10.)])
428+
momentum_dist = tfd.JointDistributionNamedAutoBatched(
429+
struct._make(tfd.Normal(0, 1 / std) for std in stddevs))
430+
kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
431+
pinned.unnormalized_log_prob,
432+
step_size=.1, num_leapfrog_steps=10,
433+
momentum_distribution=momentum_dist)
434+
bijector = pinned.experimental_default_event_space_bijector()
435+
kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bijector)
436+
state = bijector(struct._make(
437+
tfd.Uniform(-2., 2.).sample(shp)
438+
for shp in bijector.inverse_event_shape(pinned.event_shape)))
439+
self.evaluate(tfp.mcmc.sample_chain(
440+
3, current_state=state, kernel=kernel, seed=stream()))
441+
407442

408443
if __name__ == '__main__':
409444
tf.test.main()

0 commit comments

Comments
 (0)