@@ -310,9 +310,10 @@ def call_fn(
   Returns:
     ret: Return value of `fn`.
   """
-  if isinstance(args, collections.Sequence) and not _is_namedtuple_like(args):
+  if (isinstance(args, collections.abc.Sequence) and
+      not _is_namedtuple_like(args)):
     return fn(*args)
-  elif isinstance(args, collections.Mapping):
+  elif isinstance(args, collections.abc.Mapping):
     return fn(**args)
   else:
     return fn(args)
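
This hunk moves the `isinstance` checks to `collections.abc`: the bare `collections.Sequence` and `collections.Mapping` aliases were deprecated in Python 3.3 and removed in Python 3.10, so the old spelling raises `AttributeError` on current interpreters. A minimal standalone sketch of the dispatch rule, with a stand-in `_is_namedtuple_like` heuristic (fun_mc's real helper may differ):

    import collections.abc

    def _is_namedtuple_like(x):
      # Stand-in heuristic: namedtuples are tuples exposing `_fields`.
      return isinstance(x, tuple) and hasattr(x, '_fields')

    def call_fn(fn, args):
      # Sequences (except namedtuple-likes) unpack positionally, mappings
      # unpack as keyword arguments, any other value is passed through whole.
      if (isinstance(args, collections.abc.Sequence) and
          not _is_namedtuple_like(args)):
        return fn(*args)
      elif isinstance(args, collections.abc.Mapping):
        return fn(**args)
      else:
        return fn(args)

    assert call_fn(lambda a, b: a - b, [3, 1]) == 2            # fn(*args)
    assert call_fn(lambda a, b: a - b, {'a': 3, 'b': 1}) == 2  # fn(**kwargs)
    assert call_fn(lambda a: a * 2, 5) == 10                   # fn(args)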
@@ -393,7 +394,7 @@ def call_potential_fn(
         'A common solution is to adjust the `return`s in `fn` to '
         'be `return args, ()`.')

-  if not isinstance(ret, collections.Sequence) or len(ret) != 2:
+  if not isinstance(ret, collections.abc.Sequence) or len(ret) != 2:
     args_s = _tree_repr(args)
     ret_s = _tree_repr(ret)
     raise TypeError(
@@ -434,7 +435,7 @@ def call_transition_operator(
         'A common solution is to adjust the `return`s in `fn` to '
         'be `return args, ()`.')

-  if not isinstance(ret, collections.Sequence) or len(ret) != 2:
+  if not isinstance(ret, collections.abc.Sequence) or len(ret) != 2:
     args_s = _tree_repr(args)
     ret_s = _tree_repr(ret)
     raise TypeError(
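
The same `collections.abc` fix applies here, and the surrounding check enforces fun_mc's return contract: a potential function or transition operator must return a 2-element sequence `(value_or_state, extra)`, hence the suggested fix `return args, ()` in the error message. A minimal conforming transition operator (names illustrative):

    # A fun_mc-style transition operator: state in, (state, extra) out.
    def counter_step(count):
      new_count = count + 1
      extra = {'previous': count}  # side information; can be ()
      return new_count, extra      # 2-element sequence, as the check requires

    state = 0
    for _ in range(3):
      state, extra = counter_step(state)
    assert state == 3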
@@ -1185,6 +1186,7 @@ def persistent_metropolis_hastings_step(
 def gaussian_momentum_sample(state: 'Optional[State]' = None,
                              shape: 'Optional[IntTensor]' = None,
                              dtype: 'Optional[DTypeNest]' = None,
+                             named_axis: 'Optional[StringNest]' = None,
                              seed=None) -> 'State':
   """Generates a sample from a Gaussian (Normal) momentum distribution.

@@ -1197,6 +1199,7 @@ def gaussian_momentum_sample(state: 'Optional[State]' = None,
       output.
     shape: A nest of shapes, which matches the output shapes.
     dtype: A nest of dtypes, which matches the output dtypes.
+    named_axis: Named axes of the state, same structure as `state`.
     seed: For reproducibility.

   Returns:
@@ -1206,24 +1209,30 @@ def gaussian_momentum_sample(state: 'Optional[State]' = None,
   if dtype is None or shape is None:
     shape = util.map_tree(lambda t: t.shape, state)
     dtype = util.map_tree(lambda t: t.dtype, state)
+  if named_axis is None:
+    named_axis = util.map_tree(lambda _: [], dtype)

   num_seeds_needed = len(util.flatten_tree(dtype))
   seeds = list(util.split_seed(seed, num_seeds_needed))
   seeds = util.unflatten_tree(dtype, seeds)

-  def _one_part(dtype, shape, seed):
+  def _one_part(dtype, shape, seed, named_axis):
+    seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis)
     return util.random_normal(shape=shape, dtype=dtype, seed=seed)

-  return util.map_tree_up_to(dtype, _one_part, dtype, shape, seeds)
+  return util.map_tree_up_to(dtype, _one_part, dtype, shape, seeds, named_axis)


 def make_gaussian_kinetic_energy_fn(
-    chain_ndims: 'IntTensor') -> 'Callable[..., Tuple[tf.Tensor, TensorNest]]':
+    chain_ndims: 'IntTensor',
+    named_axis: 'Optional[StringNest]' = None,
+) -> 'Callable[..., Tuple[tf.Tensor, TensorNest]]':
   """Returns a function that computes the kinetic energy of a state.

   Args:
     chain_ndims: How many leading dimensions correspond to independent
       particles.
+    named_axis: Named axes of the state, same structure as `state`.

   Returns:
     kinetic_energy_fn: A callable that takes in the expanded state (see
@@ -1233,13 +1242,29 @@ def make_gaussian_kinetic_energy_fn(

   @util.named_call
   def kinetic_energy_fn(*args, **kwargs):
+    state_args = (args, kwargs)

-    def one_component(x):
-      return tf.reduce_sum(
-          tf.square(x), axis=tuple(range(chain_ndims, len(x.shape))))
-
-    return (tf.add_n(
-        [one_component(x) for x in util.flatten_tree([args, kwargs])]) / 2.), ()
+    if named_axis is None:
+      named_axis_args = util.map_tree(lambda _: [], state_args)
+    else:
+      # We need named_axis to line up with state, but state has been decomposed
+      # into args, kwargs via call_fn which called this function. Normally, we'd
+      # reconstruct the state via recover_state_from_args, but we don't have a
+      # good reference structure (named_axis is no good as it can have tuples as
+      # leaves). Instead, we go the other way, and decompose named_axis into
+      # args, kwargs. These new objects are guaranteed to line up with the
+      # decomposed state.
+      named_axis_args = call_fn(lambda *args, **kwargs: (args, kwargs),
+                                named_axis)
+
+    def _one_part(x, named_axis):
+      return backend.distribute_lib.reduce_sum(
+          tf.square(x), tuple(range(chain_ndims, len(x.shape))), named_axis)
+
+    return 0.5 * sum(
+        util.flatten_tree(
+            util.map_tree_up_to(state_args, _one_part, state_args,
+                                named_axis_args))), ()

   return kinetic_energy_fn

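`backend.distribute_lib.reduce_sum` generalizes `tf.reduce_sum`: with a non-empty `named_axis` it must also reduce across the devices that shard the state. A runnable JAX illustration of the idea (the axis name and helper here are illustrative, not fun_mc's backend API):

    import jax
    import jax.numpy as jnp

    def sharded_kinetic_energy(momentum_shard, axis_name='shard'):
      # Local sum of squares over this device's shard, then a cross-device
      # psum so every shard sees the full kinetic energy.
      local = jnp.sum(jnp.square(momentum_shard))
      return 0.5 * jax.lax.psum(local, axis_name)

    n = jax.local_device_count()
    momentum = jnp.ones([n, 3])  # 3 momentum components per device
    print(jax.pmap(sharded_kinetic_energy, axis_name='shard')(momentum))
    # -> n identical entries, each 0.5 * (n * 3)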
@@ -1315,6 +1340,7 @@ def hamiltonian_monte_carlo_step(
     energy_change_fn:
         'Callable[[IntegratorState, IntegratorState, IntegratorExtras], '
         'Tuple[FloatTensor, Any]]' = _default_hamiltonian_monte_carlo_energy_change_fn,
+    named_axis: 'Optional[StringNest]' = None,
     seed=None,
 ) -> 'Tuple[HamiltonianMonteCarloState, HamiltonianMonteCarloExtra]':
   """Hamiltonian Monte Carlo `TransitionOperator`.
@@ -1381,6 +1407,7 @@ def orig_target_log_prob_fn(x):
       Computes the change in energy between current and proposed states. By
       default, it just subtracts the current and proposed energies. A typical
       reason to override this is to improve numerical stability.
+    named_axis: Named axes of the state, same structure as `hmc_state.state`.
     seed: For reproducibility.

   Returns:
@@ -1395,11 +1422,11 @@ def orig_target_log_prob_fn(x):
   if kinetic_energy_fn is None:
     kinetic_energy_fn = make_gaussian_kinetic_energy_fn(
         len(target_log_prob.shape) if target_log_prob.shape is not None else tf  # pytype: disable=attribute-error
-        .rank(target_log_prob))
+        .rank(target_log_prob), named_axis=named_axis)

   if momentum_sample_fn is None:
     momentum_sample_fn = lambda seed: gaussian_momentum_sample(  # pylint: disable=g-long-lambda
-        state=state, seed=seed)
+        state=state, seed=seed, named_axis=named_axis)

   if integrator_fn is None:
     integrator_fn = lambda state: hamiltonian_integrator(  # pylint: disable=g-long-lambda
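
With these defaults, passing `named_axis` once to `hamiltonian_monte_carlo_step` threads it into both the momentum sampler and the kinetic energy. A hedged usage sketch: the init/step pattern and the `step_size`/`num_integrator_steps` arguments follow fun_mc's public API, but only `named_axis` and `seed` are confirmed by this diff:

    import tensorflow as tf
    from fun_mc import using_tensorflow as fun_mc

    # fun_mc potential functions return (value, extra).
    target_log_prob_fn = lambda x: (-0.5 * tf.reduce_sum(tf.square(x), -1), ())
    state = tf.zeros([16, 3])  # 16 chains on a 3-dimensional standard normal

    hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
    hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(
        hmc_state,
        target_log_prob_fn=target_log_prob_fn,
        step_size=0.5,
        num_integrator_steps=4,
        named_axis=None,  # pass the shard axis name(s) for SPMD-sharded states
        seed=42)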
@@ -1817,12 +1844,14 @@ def _one_part(state, g, learning_rate):
 def gaussian_proposal(
     state: 'State',
     scale: 'FloatNest' = 1.,
+    named_axis: 'Optional[StringNest]' = None,
     seed: 'Optional[Any]' = None) -> 'Tuple[State, Tuple[Tuple[()], float]]':
   """Axis-aligned gaussian random-walk proposal.

   Args:
     state: Current state.
     scale: Scale of the proposal.
+    named_axis: Named axes of the state, same structure as `state`.
     seed: Random seed.

   Returns:
@@ -1832,13 +1861,16 @@ def gaussian_proposal(
   scale = maybe_broadcast_structure(scale, state)
   num_parts = len(util.flatten_tree(state))
   seeds = util.unflatten_tree(state, util.split_seed(seed, num_parts))
+  if named_axis is None:
+    named_axis = util.map_tree(lambda _: [], state)

-  new_state = util.map_tree(
-      lambda x, scale, seed: x + scale * util.random_normal(  # pylint: disable=g-long-lambda
-          x.shape, x.dtype, seed),
-      state,
-      scale,
-      seeds)
+  def _sample_part(x, scale, seed, named_axis):
+    seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis)
+    return x + scale * util.random_normal(  # pylint: disable=g-long-lambda
+        x.shape, x.dtype, seed)
+
+  new_state = util.map_tree_up_to(state, _sample_part, state, scale, seeds,
+                                  named_axis)

   return new_state, ((), 0.)

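The seed folding is what keeps sharded proposals correct: the same seed tree is replicated to every device, so without `fold_in_axis_index` each shard would add identical noise. A runnable JAX illustration of the idea (names are illustrative, not fun_mc's backend API):

    import jax
    import jax.numpy as jnp

    def per_shard_noise(seed, x, axis_name='shard'):
      # Mirror of fold_in_axis_index: fold this device's position along the
      # named axis into the seed, so each shard draws distinct noise.
      seed = jax.random.fold_in(seed, jax.lax.axis_index(axis_name))
      return x + 0.1 * jax.random.normal(seed, x.shape)

    seed = jax.random.PRNGKey(0)  # the same seed reaches every device
    x = jnp.zeros([jax.local_device_count(), 2])
    out = jax.pmap(per_shard_noise, axis_name='shard', in_axes=(None, 0))(seed, x)
    print(out)  # rows differ, despite the shared parent seed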
@@ -1854,6 +1886,7 @@ def maximal_reflection_coupling_proposal(
     state: 'State',
     chain_ndims: 'int' = 0,
     scale: 'FloatNest' = 1,
+    named_axis: 'Optional[StringNest]' = None,
     epsilon: 'FloatTensor' = 1e-20,
     seed: 'Optional[Any]' = None
 ) -> 'Tuple[State, Tuple[MaximalReflectiveCouplingProposalExtra, float]]':
@@ -1869,11 +1902,15 @@ def maximal_reflection_coupling_proposal(
   dimension such that `chain_i` is coupled with `chain_i + num_chains`, where
   `num_chains = state.shape[0] // 2`

+  This function supports SPMD via sharded states in the same sense as TensorFlow
+  Probability's `tfp.experimental.distribute.Sharded`.
+
   Args:
     state: Current state of the two sets of chains.
     chain_ndims: How many leading dimensions correspond to independent chains
       (not counting the first one).
     scale: Scale of the proposal.
+    named_axis: Shard axis names, used for SPMD.
     epsilon: Small offset for numerical stability.
     seed: Random seed.

@@ -1887,51 +1924,60 @@ def maximal_reflection_coupling_proposal(
     Retrieved from http://arxiv.org/abs/2102.01790
   """

+  _sum = backend.distribute_lib.reduce_sum  # pylint: disable=invalid-name
+
   def _struct_sum(s):
     return sum(util.flatten_tree(s))

+  if named_axis is None:
+    named_axis = util.map_tree(lambda _: [], state)
   scale = maybe_broadcast_structure(scale, state)
   num_chains = util.flatten_tree(state)[0].shape[0] // 2
   mu1 = util.map_tree(lambda x: x[:num_chains], state)
   mu2 = util.map_tree(lambda x: x[num_chains:], state)
   event_dims = util.map_tree(
-      lambda x: tuple(range(chain_ndims, len(x.shape))),  # pylint: disable=g-long-lambda
+      lambda x: tuple(range(1 + chain_ndims, len(x.shape))),
       mu1)
   z = util.map_tree(lambda s, x1, x2: (x1 - x2) / s, scale, mu1, mu2)
   z_norm = tf.sqrt(
       _struct_sum(
-          util.map_tree_up_to(z, lambda z, ed: tf.reduce_sum(tf.square(z), ed),
-                              z, event_dims)))
+          util.map_tree_up_to(z, lambda z, ed, na: _sum(tf.square(z), ed, na),
+                              z, event_dims, named_axis)))
   e = util.map_tree(
       lambda z: z /  # pylint: disable=g-long-lambda
       (tf.reshape(z_norm, z_norm.shape + (1,) *
                   (len(z.shape) - len(z_norm.shape))) + epsilon),
       z)
-  batch_shape = util.flatten_tree(mu1)[0].shape[:chain_ndims]
+  batch_shape = util.flatten_tree(mu1)[0].shape[1:1 + chain_ndims]

   num_parts = len(util.flatten_tree(state))
   all_seeds = util.split_seed(seed, num_parts + 1)
   x_seeds = util.unflatten_tree(state, all_seeds[:num_parts])
   couple_seed = all_seeds[-1]

-  x = util.map_tree(lambda x, seed: util.random_normal(x.shape, x.dtype, seed),
-                    mu1, x_seeds)
+  def _sample_part(x, seed, named_axis):
+    seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis)
+    return util.random_normal(x.shape, x.dtype, seed)
+
+  x = util.map_tree_up_to(mu1, _sample_part, mu1, x_seeds, named_axis)

   e_dot_x = _struct_sum(
       util.map_tree_up_to(
           x,
-          lambda x, e, ed: tf.reduce_sum(x * e, ed),  # pylint: disable=g-long-lambda
+          lambda x, e, ed, na: _sum(x * e, ed, na),
           x,
           e,
-          event_dims))
+          event_dims,
+          named_axis))

   log_couple_ratio = _struct_sum(
       util.map_tree_up_to(
           x,
-          lambda x, z, ed: -tf.reduce_sum(x * z + tf.square(z) / 2, ed),  # pylint: disable=g-long-lambda
+          lambda x, z, ed, na: -_sum(x * z + tf.square(z) / 2, ed, na),
           x,
           z,
-          event_dims))
+          event_dims,
+          named_axis))

   p_couple = tf.exp(tf.minimum(0., log_couple_ratio))
   coupling_proposed = util.random_uniform(
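
For reference, the coupling rule that the new `_sum`/`named_axis` plumbing preserves: propose `x1 ~ N(mu1, s)`, accept the coupling with probability `exp(min(0, log_couple_ratio))`, and otherwise reflect the draw across the `mu1 - mu2` axis. A single-part, unsharded, unit-scale numpy sketch; the `rng` calls stand in for `util.random_normal`/`util.random_uniform`, and the reflect branch follows Thin et al. (2021), since the tail of this function lies outside the diff:

    import numpy as np

    def reflection_couple(mu1, mu2, rng):
      z = mu1 - mu2                       # scaled difference (scale = 1)
      e = z / np.linalg.norm(z)           # unit vector along the difference
      x = rng.standard_normal(mu1.shape)  # standard normal draw for chain 1
      log_couple_ratio = -np.sum(x * z + np.square(z) / 2)
      if np.log(rng.uniform()) < log_couple_ratio:
        y = x + z                         # coupled: both proposals coincide
      else:
        y = x - 2 * np.sum(x * e) * e     # reflected across the mu1-mu2 axis
      return mu1 + x, mu2 + y             # proposals for the two chains

    rng = np.random.default_rng(0)
    print(reflection_couple(np.zeros(2), np.ones(2), rng))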