Skip to content

Commit c1c8f74

Browse files
sharadmvtensorflower-gardener
authored andcommitted
Add batch axis names for gradient-based trajectory length adaptation
PiperOrigin-RevId: 378240709
1 parent c6e4fab commit c1c8f74

File tree

2 files changed

+98
-15
lines changed

2 files changed

+98
-15
lines changed

tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Gradient-based trajectory length adaptation kernel."""
1616

1717
import collections
18+
import functools
1819

1920
import tensorflow.compat.v2 as tf
2021

@@ -89,6 +90,16 @@ def _map_structure_up_to_with_axes(structure, fn, *args,
8990
experimental_shard_axis_names)
9091

9192

93+
def _reduce_with_axes(index_op, name_op, x, axis_idx=None, axis_names=None):
94+
return name_op(index_op(x, axis_idx), axis_names)
95+
96+
97+
_reduce_sum_with_axes = functools.partial(_reduce_with_axes, tf.reduce_sum,
98+
distribute_lib.psum)
99+
_reduce_mean_with_axes = functools.partial(_reduce_with_axes, tf.reduce_mean,
100+
distribute_lib.pmean)
101+
102+
92103
def hmc_like_num_leapfrog_steps_getter_fn(kernel_results):
93104
"""Getter for `num_leapfrog_steps` so it can be inspected."""
94105
return unnest.get_innermost(kernel_results, 'num_leapfrog_steps')
@@ -132,7 +143,8 @@ def chees_criterion(previous_state,
132143
proposed_state,
133144
accept_prob,
134145
validate_args=False,
135-
experimental_shard_axis_names=None):
146+
experimental_shard_axis_names=None,
147+
experimental_chain_axis_names=None):
136148
"""The ChEES criterion from [1].
137149
138150
ChEES stands for Change in the Estimator of the Expected Square.
@@ -166,6 +178,8 @@ def chees_criterion(previous_state,
166178
validate_args: Whether to perform non-static argument validation.
167179
experimental_shard_axis_names: A structure of string names indicating how
168180
members of the state are sharded.
181+
experimental_chain_axis_names: A string or list of string names indicating
182+
how batches of chains are sharded.
169183
170184
Returns:
171185
chees: The value of the ChEES criterion.
@@ -182,7 +196,13 @@ def chees_criterion(previous_state,
182196
"""
183197
batch_ndims = ps.rank(accept_prob)
184198
batch_axes = ps.range(batch_ndims, dtype=tf.int32)
185-
num_chains = ps.size(accept_prob)
199+
experimental_chain_axis_names = distribute_lib.canonicalize_axis_name(
200+
experimental_chain_axis_names)
201+
# Number of total chains is local batch size * distributed axis size
202+
local_axis_size = ps.maximum(ps.size(accept_prob), 1)
203+
distributed_axis_size = int(ps.reduce_prod([
204+
distribute_lib.get_axis_size(a) for a in experimental_chain_axis_names]))
205+
num_chains = local_axis_size * distributed_axis_size
186206
num_chains_ = tf.get_static_value(num_chains)
187207
if num_chains_ is not None:
188208
if num_chains_ < 2:
@@ -199,7 +219,9 @@ def chees_criterion(previous_state,
199219
def _center_previous_state(x):
200220
# The empirical mean here is a stand-in for the true mean, so we drop the
201221
# gradient that flows through this term.
202-
return x - tf.stop_gradient(tf.reduce_mean(x, axis=batch_axes))
222+
x_mean = _reduce_mean_with_axes(
223+
x, batch_axes, experimental_chain_axis_names)
224+
return x - tf.stop_gradient(x_mean)
203225

204226
def _center_proposed_state(x):
205227
# The empirical mean here is a stand-in for the true mean, so we drop the
@@ -216,8 +238,10 @@ def _center_proposed_state(x):
216238
# If all accept_prob's are zero, the x_center will have a nonsense value,
217239
# but we'll discard the resultant gradients later on, so it's fine.
218240
x_center = (
219-
tf.reduce_sum(expanded_accept_prob * x_safe, axis=batch_axes) /
220-
(tf.reduce_sum(expanded_accept_prob, axis=batch_axes) + 1e-20))
241+
_reduce_sum_with_axes(expanded_accept_prob * x_safe, batch_axes,
242+
experimental_chain_axis_names) /
243+
(_reduce_sum_with_axes(expanded_accept_prob, batch_axes,
244+
experimental_chain_axis_names) + 1e-20))
221245

222246
return x - tf.stop_gradient(x_center)
223247

@@ -358,6 +382,7 @@ def __init__(
358382
proposed_state_getter_fn=hmc_like_proposed_state_getter_fn,
359383
validate_args=False,
360384
experimental_shard_axis_names=None,
385+
experimental_chain_axis_names=None,
361386
name=None):
362387
"""Creates the trajectory length adaptation kernel.
363388
@@ -414,6 +439,8 @@ def __init__(
414439
outputs.
415440
experimental_shard_axis_names: A structure of string names indicating how
416441
members of the state are sharded.
442+
experimental_chain_axis_names: A string or list of string names indicating
443+
how batches of chains are sharded.
417444
name: Python `str` name prefixed to Ops created by this class. Default:
418445
'simple_step_size_adaptation'.
419446
@@ -452,6 +479,7 @@ class docstring).
452479
proposed_state_getter_fn=hmc_like_proposed_state_getter_fn,
453480
validate_args=validate_args,
454481
experimental_shard_axis_names=experimental_shard_axis_names,
482+
experimental_chain_axis_names=experimental_chain_axis_names,
455483
name=name,
456484
)
457485

@@ -468,12 +496,15 @@ def num_adaptation_steps(self):
468496
return self._parameters['num_adaptation_steps']
469497

470498
def criterion_fn(self, previous_state, proposed_state, accept_prob):
471-
if self.experimental_shard_axis_names is None:
472-
return self._parameters['criterion_fn'](previous_state, proposed_state,
473-
accept_prob)
474-
return self._parameters['criterion_fn'](
475-
previous_state, proposed_state, accept_prob,
476-
experimental_shard_axis_names=self.experimental_shard_axis_names)
499+
kwargs = {}
500+
if self.experimental_chain_axis_names is not None:
501+
kwargs['experimental_chain_axis_names'] = (
502+
self.experimental_chain_axis_names)
503+
if self.experimental_shard_axis_names is not None:
504+
kwargs['experimental_shard_axis_names'] = (
505+
self.experimental_shard_axis_names)
506+
return self._parameters['criterion_fn'](previous_state, proposed_state,
507+
accept_prob, **kwargs)
477508

478509
@property
479510
def max_leapfrog_steps(self):
@@ -567,7 +598,8 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
567598
step_size=step_size,
568599
criterion_fn=self.criterion_fn,
569600
max_leapfrog_steps=self.max_leapfrog_steps,
570-
experimental_shard_axis_names=self.experimental_shard_axis_names)
601+
experimental_shard_axis_names=self.experimental_shard_axis_names,
602+
experimental_chain_axis_names=self.experimental_chain_axis_names)
571603

572604
# Undo the effect of adaptation if we're not in the burnin phase. We keep
573605
# the criterion, however, as that's a diagnostic. We also keep the
@@ -623,9 +655,16 @@ def is_calibrated(self):
623655
def experimental_shard_axis_names(self):
624656
return self._parameters['experimental_shard_axis_names']
625657

658+
@property
659+
def experimental_chain_axis_names(self):
660+
return self._parameters['experimental_chain_axis_names']
661+
626662
def experimental_with_shard_axes(self, shard_axis_names):
627663
return self.copy(experimental_shard_axis_names=shard_axis_names)
628664

665+
def experimental_with_chain_axes(self, chain_axis_names):
666+
return self.copy(experimental_chain_axis_names=chain_axis_names)
667+
629668

630669
def _forbid_inner_transformed_kernel(inner_kernel):
631670
"""Forbids inner kernel from containing `TransformedTransitionKernel`."""
@@ -669,7 +708,8 @@ def _update_trajectory_grad(previous_kernel_results, previous_state,
669708
proposed_state, proposed_velocity,
670709
trajectory_jitter, accept_prob, step_size,
671710
criterion_fn, max_leapfrog_steps,
672-
experimental_shard_axis_names=None):
711+
experimental_shard_axis_names=None,
712+
experimental_chain_axis_names=None):
673713
"""Updates the trajectory length."""
674714
# Compute criterion grads.
675715
def leapfrog_action(dt):
@@ -693,12 +733,16 @@ def adjust_state(x, v, shard_axes=None):
693733
trajectory_grad *= trajectory_jitter
694734

695735
# Weight by acceptance probability.
736+
experimental_chain_axis_names = distribute_lib.canonicalize_axis_name(
737+
experimental_chain_axis_names)
696738
trajectory_grad = tf.where(accept_prob > 1e-4, trajectory_grad, 0.)
697739
trajectory_grad = tf.where(
698740
tf.math.is_finite(trajectory_grad), trajectory_grad, 0.)
699741
trajectory_grad = (
700-
tf.reduce_sum(trajectory_grad * accept_prob) /
701-
tf.reduce_sum(accept_prob + 1e-20))
742+
_reduce_sum_with_axes(trajectory_grad * accept_prob,
743+
None, experimental_chain_axis_names) /
744+
_reduce_sum_with_axes(accept_prob + 1e-20, None,
745+
experimental_chain_axis_names))
702746

703747
# Compute Adam/RMSProp step size.
704748
dtype = previous_kernel_results.adaptation_rate.dtype

tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,45 @@ def run(seed):
396396
self.assertAllClose(avg_sq_grad[0], avg_sq_grad[i])
397397
self.assertAllClose(avg_max_tl[0], avg_max_tl[i])
398398

399+
def test_gbtla_kernel_can_shard_chains_across_devices(self):
400+
401+
def target_log_prob(a, b):
402+
return (
403+
tfd.Normal(0., 1.).log_prob(a)
404+
+ tfd.Sample(tfd.Normal(a, 1.), 4).log_prob(b))
405+
406+
kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob,
407+
step_size=1e-2,
408+
num_leapfrog_steps=2)
409+
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
410+
kernel, 10)
411+
sharded_kernel = kernel.experimental_with_chain_axes(self.axis_name)
412+
413+
def run(seed):
414+
init_seed, sample_seed = samplers.split_seed(seed)
415+
state_seeds = samplers.split_seed(init_seed)
416+
state = [
417+
samplers.normal(seed=state_seeds[0], shape=[]),
418+
samplers.normal(seed=state_seeds[1], shape=[4])
419+
]
420+
kr = sharded_kernel.bootstrap_results(state)
421+
_, kr = sharded_kernel.one_step(state, kr, seed=sample_seed)
422+
return (
423+
kr.averaged_sq_grad,
424+
kr.averaged_max_trajectory_length
425+
)
426+
427+
seeds = self.shard_values(tf.stack(tfp.random.split_seed(
428+
samplers.zeros_seed(), distribute_test_lib.NUM_DEVICES)), 0)
429+
430+
avg_sq_grad, avg_max_tl = self.evaluate(
431+
self.per_replica_to_tensor(self.strategy_run(
432+
run, args=(seeds,), axis_name=self.axis_name), 0))
433+
434+
for i in range(distribute_test_lib.NUM_DEVICES):
435+
self.assertAllClose(avg_sq_grad[0], avg_sq_grad[i])
436+
self.assertAllClose(avg_max_tl[0], avg_max_tl[i])
437+
399438

400439
del _GradientBasedTrajectoryLengthAdaptationTest
401440

0 commit comments

Comments
 (0)