
Commit b33f130

SiegeLordEx authored and tensorflower-gardener committed
FunMC: Add explicit state named axis support, for convenience.
Also:
- Fix some collections ABC deprecations
- Fix the chain_ndims being off by 1 in the maximal coupling proposal

PiperOrigin-RevId: 385606417
1 parent d5a7005 commit b33f130

File tree: 4 files changed (+288, -68 lines)


spinoffs/fun_mc/fun_mc/fun_mc_lib.py

Lines changed: 77 additions & 31 deletions
@@ -310,9 +310,10 @@ def call_fn(
   Returns:
     ret: Return value of `fn`.
   """
-  if isinstance(args, collections.Sequence) and not _is_namedtuple_like(args):
+  if (isinstance(args, collections.abc.Sequence) and
+      not _is_namedtuple_like(args)):
     return fn(*args)
-  elif isinstance(args, collections.Mapping):
+  elif isinstance(args, collections.abc.Mapping):
     return fn(**args)
   else:
     return fn(args)
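
A note on the fix above: the container ABCs moved to `collections.abc` in Python 3.3, and the deprecated aliases under `collections` were removed in Python 3.10, so the old spellings warn on 3.3 through 3.9 and raise `AttributeError` afterwards. A minimal standalone illustration (plain Python, independent of FunMC):

```python
import collections.abc

# Sequences (tuples, lists) unpack to *args in call_fn; Mappings (dicts)
# unpack to **kwargs. Since Python 3.3 these ABCs live in collections.abc,
# and the collections.Sequence alias is gone as of Python 3.10.
print(isinstance((1, 2, 3), collections.abc.Sequence))  # True
print(isinstance({'a': 1}, collections.abc.Mapping))    # True
print(isinstance({'a': 1}, collections.abc.Sequence))   # False
```
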
@@ -393,7 +394,7 @@ def call_potential_fn(
         'A common solution is to adjust the `return`s in `fn` to '
         'be `return args, ()`.')
 
-  if not isinstance(ret, collections.Sequence) or len(ret) != 2:
+  if not isinstance(ret, collections.abc.Sequence) or len(ret) != 2:
     args_s = _tree_repr(args)
     ret_s = _tree_repr(ret)
     raise TypeError(
@@ -434,7 +435,7 @@ def call_transition_operator(
         'A common solution is to adjust the `return`s in `fn` to '
         'be `return args, ()`.')
 
-  if not isinstance(ret, collections.Sequence) or len(ret) != 2:
+  if not isinstance(ret, collections.abc.Sequence) or len(ret) != 2:
     args_s = _tree_repr(args)
     ret_s = _tree_repr(ret)
     raise TypeError(
@@ -1185,6 +1186,7 @@ def persistent_metropolis_hastings_step(
 def gaussian_momentum_sample(state: 'Optional[State]' = None,
                              shape: 'Optional[IntTensor]' = None,
                              dtype: 'Optional[DTypeNest]' = None,
+                             named_axis: 'Optional[StringNest]' = None,
                              seed=None) -> 'State':
   """Generates a sample from a Gaussian (Normal) momentum distribution.
 
@@ -1197,6 +1199,7 @@ def gaussian_momentum_sample(state: 'Optional[State]' = None,
       output.
     shape: A nest of shapes, which matches the output shapes.
     dtype: A nest of dtypes, which matches the output dtypes.
+    named_axis: Named axes of the state, same structure as `state`.
     seed: For reproducibility.
 
   Returns:
@@ -1206,24 +1209,30 @@ def gaussian_momentum_sample(state: 'Optional[State]' = None,
   if dtype is None or shape is None:
     shape = util.map_tree(lambda t: t.shape, state)
     dtype = util.map_tree(lambda t: t.dtype, state)
+  if named_axis is None:
+    named_axis = util.map_tree(lambda _: [], dtype)
 
   num_seeds_needed = len(util.flatten_tree(dtype))
   seeds = list(util.split_seed(seed, num_seeds_needed))
   seeds = util.unflatten_tree(dtype, seeds)
 
-  def _one_part(dtype, shape, seed):
+  def _one_part(dtype, shape, seed, named_axis):
+    seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis)
     return util.random_normal(shape=shape, dtype=dtype, seed=seed)
 
-  return util.map_tree_up_to(dtype, _one_part, dtype, shape, seeds)
+  return util.map_tree_up_to(dtype, _one_part, dtype, shape, seeds, named_axis)
 
 
 def make_gaussian_kinetic_energy_fn(
-    chain_ndims: 'IntTensor') -> 'Callable[..., Tuple[tf.Tensor, TensorNest]]':
+    chain_ndims: 'IntTensor',
+    named_axis: 'Optional[StringNest]' = None,
+) -> 'Callable[..., Tuple[tf.Tensor, TensorNest]]':
   """Returns a function that computes the kinetic energy of a state.
 
   Args:
     chain_ndims: How many leading dimensions correspond to independent
       particles.
+    named_axis: Named axes of the state, same structure as `state`.
 
   Returns:
     kinetic_energy_fn: A callable that takes in the expanded state (see
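
The new `named_axis` argument drives one behavior here: before sampling, each state part's seed passes through `backend.distribute_lib.fold_in_axis_index`, which mixes the current shard's position along the named axes into the seed so that shards draw independent noise (with an empty axis list, the seed passes through unchanged). A rough sketch of the idea in JAX terms; this `fold_in_axis_index` is a hypothetical stand-in, not FunMC's actual backend implementation:

```python
import jax

def fold_in_axis_index(seed, named_axis):
  # Mix this device's index along each named mapped axis into the seed, so
  # different shards of a pmap produce different random draws. Must run
  # inside a context (e.g. jax.pmap) that binds the axis names.
  for axis_name in named_axis:
    seed = jax.random.fold_in(seed, jax.lax.axis_index(axis_name))
  return seed

# Inside e.g. jax.pmap(..., axis_name='data'), a per-shard momentum draw:
def sample_momentum(seed, shape):
  seed = fold_in_axis_index(seed, ['data'])
  return jax.random.normal(seed, shape)
```
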
@@ -1233,13 +1242,29 @@ def make_gaussian_kinetic_energy_fn(
 
   @util.named_call
   def kinetic_energy_fn(*args, **kwargs):
+    state_args = (args, kwargs)
 
-    def one_component(x):
-      return tf.reduce_sum(
-          tf.square(x), axis=tuple(range(chain_ndims, len(x.shape))))
-
-    return (tf.add_n(
-        [one_component(x) for x in util.flatten_tree([args, kwargs])]) / 2.), ()
+    if named_axis is None:
+      named_axis_args = util.map_tree(lambda _: [], state_args)
+    else:
+      # We need named_axis to line up with state, but state has been decomposed
+      # into args, kwargs via call_fn which called this function. Normally, we'd
+      # reconstruct the state via recover_state_from_args, but we don't have a
+      # good reference structure (named_axis is no good as it can have tuples as
+      # leaves). Instead, we go the other way, and decompose named_axis into
+      # args, kwargs. These new objects are guaranteed to line up with the
+      # decomposed state.
+      named_axis_args = call_fn(lambda *args, **kwargs: (args, kwargs),
+                                named_axis)
+
+    def _one_part(x, named_axis):
+      return backend.distribute_lib.reduce_sum(
+          tf.square(x), tuple(range(chain_ndims, len(x.shape))), named_axis)
+
+    return 0.5 * sum(
+        util.flatten_tree(
+            util.map_tree_up_to(state_args, _one_part, state_args,
+                                named_axis_args))), ()
 
   return kinetic_energy_fn
 
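The long comment in that hunk is worth restating: `kinetic_energy_fn` is invoked through `call_fn`, so by the time it runs, the original state nest has already been flattened into `(*args, **kwargs)`. Rather than reassembling the state, the code decomposes `named_axis` through the same calling convention, which guarantees the two `(args, kwargs)` pairs line up leaf for leaf. A toy standalone version of that round trip, using a simplified `call_fn` (the real one also special-cases namedtuples):

```python
import collections.abc

def call_fn(fn, args):
  # Simplified version of fun_mc's calling convention: sequences unpack to
  # *args, mappings to **kwargs, anything else is a single positional arg.
  if isinstance(args, collections.abc.Sequence) and not isinstance(args, str):
    return fn(*args)
  elif isinstance(args, collections.abc.Mapping):
    return fn(**args)
  else:
    return fn(args)

state = {'x': 1.0, 'y': 2.0}           # the state nest
named_axis = {'x': [], 'y': ['data']}  # same structure, axis-name leaves

def kinetic_energy_fn(*args, **kwargs):
  state_args = (args, kwargs)
  # Decompose named_axis exactly the way the state was decomposed.
  named_axis_args = call_fn(lambda *a, **kw: (a, kw), named_axis)
  return state_args, named_axis_args

print(call_fn(kinetic_energy_fn, state))
# (((), {'x': 1.0, 'y': 2.0}), ((), {'x': [], 'y': ['data']}))
```
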
@@ -1315,6 +1340,7 @@ def hamiltonian_monte_carlo_step(
     energy_change_fn:
         'Callable[[IntegratorState, IntegratorState, IntegratorExtras], '
         'Tuple[FloatTensor, Any]]' = _default_hamiltonian_monte_carlo_energy_change_fn,
+    named_axis: 'Optional[StringNest]' = None,
     seed=None,
 ) -> 'Tuple[HamiltonianMonteCarloState, HamiltonianMonteCarloExtra]':
   """Hamiltonian Monte Carlo `TransitionOperator`.
@@ -1381,6 +1407,7 @@ def orig_target_log_prob_fn(x):
       Computes the change in energy between current and proposed states. By
       default, it just substracts the current and proposed energies. A typical
       reason to override this is to improve numerical stability.
+    named_axis: Named axes of the state, same structure as `hmc_state.state`.
     seed: For reproducibility.
 
   Returns:
@@ -1395,11 +1422,11 @@ def orig_target_log_prob_fn(x):
   if kinetic_energy_fn is None:
     kinetic_energy_fn = make_gaussian_kinetic_energy_fn(
         len(target_log_prob.shape) if target_log_prob.shape is not None else tf  # pytype: disable=attribute-error
-        .rank(target_log_prob))
+        .rank(target_log_prob), named_axis=named_axis)
 
   if momentum_sample_fn is None:
     momentum_sample_fn = lambda seed: gaussian_momentum_sample(  # pylint: disable=g-long-lambda
-        state=state, seed=seed)
+        state=state, seed=seed, named_axis=named_axis)
 
   if integrator_fn is None:
     integrator_fn = lambda state: hamiltonian_integrator(  # pylint: disable=g-long-lambda
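
With the defaults wired this way, `hamiltonian_monte_carlo_step` forwards `named_axis` to both the kinetic energy function and the momentum sampler, so a sharded chain only has to pass the argument once. A hedged usage sketch, assuming FunMC's JAX backend; the target density and the `'data'` axis name are illustrative, and `named_axis` should mirror the state's structure per the docstring:

```python
# Sketch only: assumes FunMC's JAX backend and a pmap-style 'data' shard axis.
import jax
import jax.numpy as jnp
from fun_mc import using_jax as fun_mc

def target_log_prob_fn(x):
  # Illustrative log-density; a real SPMD model would compute some terms
  # from data sharded along the 'data' axis.
  return -0.5 * jnp.sum(jnp.square(x)), ()

def kernel(hmc_state, seed):
  seed, hmc_seed = jax.random.split(seed)
  hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
      hmc_state,
      target_log_prob_fn=target_log_prob_fn,
      step_size=0.1,
      num_integrator_steps=8,
      # Forwarded to the default kinetic energy fn and momentum sampler.
      named_axis=['data'],
      seed=hmc_seed)
  return hmc_state, seed

hmc_state = fun_mc.hamiltonian_monte_carlo_init(
    jnp.zeros(4), target_log_prob_fn)
```
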
@@ -1817,12 +1844,14 @@ def _one_part(state, g, learning_rate):
 def gaussian_proposal(
     state: 'State',
     scale: 'FloatNest' = 1.,
+    named_axis: 'Optional[StringNest]' = None,
     seed: 'Optional[Any]' = None) -> 'Tuple[State, Tuple[Tuple[()], float]]':
   """Axis-aligned gaussian random-walk proposal.
 
   Args:
     state: Current state.
     scale: Scale of the proposal.
+    named_axis: Named axes of the state, same structure as `state`.
     seed: Random seed.
 
   Returns:
@@ -1832,13 +1861,16 @@ def maximal_reflection_coupling_proposal(
   scale = maybe_broadcast_structure(scale, state)
   num_parts = len(util.flatten_tree(state))
   seeds = util.unflatten_tree(state, util.split_seed(seed, num_parts))
+  if named_axis is None:
+    named_axis = util.map_tree(lambda _: [], state)
 
-  new_state = util.map_tree(
-      lambda x, scale, seed: x + scale * util.random_normal(  # pylint: disable=g-long-lambda
-          x.shape, x.dtype, seed),
-      state,
-      scale,
-      seeds)
+  def _sample_part(x, scale, seed, named_axis):
+    seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis)
+    return x + scale * util.random_normal(  # pylint: disable=g-long-lambda
+        x.shape, x.dtype, seed)
+
+  new_state = util.map_tree_up_to(state, _sample_part, state, scale, seeds,
+                                  named_axis)
 
   return new_state, ((), 0.)
 
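The rewritten `gaussian_proposal` follows the same recipe as the momentum sampler: split the seed into one sub-seed per state leaf, unflatten the seeds to the state's structure, fold in the shard index, then map the sampling rule over the matched trees. The tree bookkeeping alone, sketched with JAX's tree utilities as hypothetical stand-ins for FunMC's `util.map_tree`/`util.split_seed` helpers:

```python
import jax
import jax.numpy as jnp

state = {'loc': jnp.zeros(3), 'log_scale': jnp.zeros(())}
scale = {'loc': 0.5, 'log_scale': 0.1}

# One sub-seed per leaf, arranged in the same structure as the state.
seed = jax.random.PRNGKey(0)
leaves, treedef = jax.tree_util.tree_flatten(state)
seeds = jax.tree_util.tree_unflatten(
    treedef, list(jax.random.split(seed, len(leaves))))

# Independent Gaussian perturbation for each state part.
new_state = jax.tree_util.tree_map(
    lambda x, s, sd: x + s * jax.random.normal(sd, x.shape),
    state, scale, seeds)
```
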
@@ -1854,6 +1886,7 @@ def maximal_reflection_coupling_proposal(
     state: 'State',
     chain_ndims: 'int' = 0,
     scale: 'FloatNest' = 1,
+    named_axis: 'Optional[StringNest]' = None,
     epsilon: 'FloatTensor' = 1e-20,
     seed: 'Optional[Any]' = None
 ) -> 'Tuple[State, Tuple[MaximalReflectiveCouplingProposalExtra, float]]':
@@ -1869,11 +1902,15 @@ def maximal_reflection_coupling_proposal(
   dimension such that `chain_i` is coupled with `chain_i + num_chains`, where
   `num_chains = state.shape[0] // 2`
 
+  This function supports SPMD via sharded states in the same sense as TensorFlow
+  Probability's `tfp.experimental.distribute.Sharded`.
+
   Args:
     state: Current state of the two sets of chains.
     chain_ndims: How many leading dimensions correspond to independent chains
       (not counting the first one).
     scale: Scale of the proposal.
+    named_axis: Shard axes names, used for SPMD.
     epsilon: Small offset for numerical stability.
     seed: Random seed.
 
@@ -1887,51 +1924,60 @@ def maximal_reflection_coupling_proposal(
     Retrieved from http://arxiv.org/abs/2102.01790
   """
 
+  _sum = backend.distribute_lib.reduce_sum  # pylint: disable=invalid-name
+
   def _struct_sum(s):
     return sum(util.flatten_tree(s))
 
+  if named_axis is None:
+    named_axis = util.map_tree(lambda _: [], state)
   scale = maybe_broadcast_structure(scale, state)
   num_chains = util.flatten_tree(state)[0].shape[0] // 2
   mu1 = util.map_tree(lambda x: x[:num_chains], state)
   mu2 = util.map_tree(lambda x: x[num_chains:], state)
   event_dims = util.map_tree(
-      lambda x: tuple(range(chain_ndims, len(x.shape))),  # pylint: disable=g-long-lambda
+      lambda x: tuple(range(1 + chain_ndims, len(x.shape))),
       mu1)
   z = util.map_tree(lambda s, x1, x2: (x1 - x2) / s, scale, mu1, mu2)
   z_norm = tf.sqrt(
       _struct_sum(
-          util.map_tree_up_to(z, lambda z, ed: tf.reduce_sum(tf.square(z), ed),
-                              z, event_dims)))
+          util.map_tree_up_to(z, lambda z, ed, na: _sum(tf.square(z), ed, na),
+                              z, event_dims, named_axis)))
   e = util.map_tree(
       lambda z: z /  # pylint: disable=g-long-lambda
       (tf.reshape(z_norm, z_norm.shape + (1,) *
                   (len(z.shape) - len(z_norm.shape))) + epsilon),
       z)
-  batch_shape = util.flatten_tree(mu1)[0].shape[:chain_ndims]
+  batch_shape = util.flatten_tree(mu1)[0].shape[1:1 + chain_ndims]
 
   num_parts = len(util.flatten_tree(state))
   all_seeds = util.split_seed(seed, num_parts + 1)
   x_seeds = util.unflatten_tree(state, all_seeds[:num_parts])
   couple_seed = all_seeds[-1]
 
-  x = util.map_tree(lambda x, seed: util.random_normal(x.shape, x.dtype, seed),
-                    mu1, x_seeds)
+  def _sample_part(x, seed, named_axis):
+    seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis)
+    return util.random_normal(x.shape, x.dtype, seed)
+
+  x = util.map_tree_up_to(mu1, _sample_part, mu1, x_seeds, named_axis)
 
   e_dot_x = _struct_sum(
       util.map_tree_up_to(
           x,
-          lambda x, e, ed: tf.reduce_sum(x * e, ed),  # pylint: disable=g-long-lambda
+          lambda x, e, ed, na: _sum(x * e, ed, na),
           x,
           e,
-          event_dims))
+          event_dims,
+          named_axis))
 
   log_couple_ratio = _struct_sum(
       util.map_tree_up_to(
           x,
-          lambda x, z, ed: -tf.reduce_sum(x * z + tf.square(z) / 2, ed),  # pylint: disable=g-long-lambda
+          lambda x, z, ed, na: -_sum(x * z + tf.square(z) / 2, ed, na),
           x,
           z,
-          event_dims))
+          event_dims,
+          named_axis))
 
   p_couple = tf.exp(tf.minimum(0., log_couple_ratio))
   coupling_proposed = util.random_uniform(
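
The `1 + chain_ndims` and `shape[1:1 + chain_ndims]` edits above are the off-by-one fix called out in the commit message: axis 0 of the state indexes the coupled pairs (size `2 * num_chains`), so after slicing out `mu1 = x[:num_chains]` the independent-chain axes begin at axis 1, not axis 0. A small shape walk-through under assumed dimensions:

```python
# Assumed layout: state is [2 * num_chains, chain_dims..., event_dims...],
# here with one independent-chain axis of size 5 and an event axis of size 3.
state_shape = (8, 5, 3)
chain_ndims = 1

num_chains = state_shape[0] // 2             # 4 coupled pairs
mu1_shape = (num_chains,) + state_shape[1:]  # (4, 5, 3) after x[:num_chains]

# Old (off by one): the coupling axis was counted as a chain axis, so the
# event reduction swallowed the real chain axis.
event_dims_old = tuple(range(chain_ndims, len(mu1_shape)))      # (1, 2)
batch_shape_old = mu1_shape[:chain_ndims]                       # (4,)

# New: axis 0 is the coupling axis; chain axes start at 1, events after them.
event_dims_new = tuple(range(1 + chain_ndims, len(mu1_shape)))  # (2,)
batch_shape_new = mu1_shape[1:1 + chain_ndims]                  # (5,)
```
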
