 """Gradient-based trajectory length adaptation kernel."""
 
 import collections
+import functools
 
 import tensorflow.compat.v2 as tf
 
@@ -89,6 +90,16 @@ def _map_structure_up_to_with_axes(structure, fn, *args,
       experimental_shard_axis_names)
 
 
+def _reduce_with_axes(index_op, name_op, x, axis_idx=None, axis_names=None):
+  return name_op(index_op(x, axis_idx), axis_names)
+
+
+_reduce_sum_with_axes = functools.partial(_reduce_with_axes, tf.reduce_sum,
+                                          distribute_lib.psum)
+_reduce_mean_with_axes = functools.partial(_reduce_with_axes, tf.reduce_mean,
+                                           distribute_lib.pmean)
+
+
 def hmc_like_num_leapfrog_steps_getter_fn(kernel_results):
   """Getter for `num_leapfrog_steps` so it can be inspected."""
   return unnest.get_innermost(kernel_results, 'num_leapfrog_steps')
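For readers outside TFP internals: the helper composes a local reduction over positional axes (`tf.reduce_sum`/`tf.reduce_mean`) with a named-axis reduction across devices (`distribute_lib.psum`/`distribute_lib.pmean`). A minimal standalone sketch of the same composition, with a hypothetical `fake_psum` standing in for `distribute_lib.psum`:

```python
import functools

import tensorflow.compat.v2 as tf


def fake_psum(x, axis_names=None):
  # Stand-in for `distribute_lib.psum`, which sums across the named, mapped
  # axes of a distributed computation; with no names it is the identity.
  del axis_names  # unused in this single-device sketch
  return x


def _reduce_with_axes(index_op, name_op, x, axis_idx=None, axis_names=None):
  # First reduce over local (positional) axes, then over named device axes.
  return name_op(index_op(x, axis_idx), axis_names)


reduce_sum_with_axes = functools.partial(_reduce_with_axes, tf.reduce_sum,
                                         fake_psum)

x = tf.constant([[1., 2.], [3., 4.]])
print(reduce_sum_with_axes(x, axis_idx=[0, 1]).numpy())  # 10.0
```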
@@ -132,7 +143,8 @@ def chees_criterion(previous_state,
                     proposed_state,
                     accept_prob,
                     validate_args=False,
-                    experimental_shard_axis_names=None):
+                    experimental_shard_axis_names=None,
+                    experimental_chain_axis_names=None):
   """The ChEES criterion from [1].
 
   ChEES stands for Change in the Estimator of the Expected Square.
@@ -166,6 +178,8 @@ def chees_criterion(previous_state,
     validate_args: Whether to perform non-static argument validation.
     experimental_shard_axis_names: A structure of string names indicating how
       members of the state are sharded.
+    experimental_chain_axis_names: A string or list of string names indicating
+      how batches of chains are sharded.
 
   Returns:
     chees: The value of the ChEES criterion.
@@ -182,7 +196,13 @@ def chees_criterion(previous_state,
   """
   batch_ndims = ps.rank(accept_prob)
   batch_axes = ps.range(batch_ndims, dtype=tf.int32)
-  num_chains = ps.size(accept_prob)
+  experimental_chain_axis_names = distribute_lib.canonicalize_axis_name(
+      experimental_chain_axis_names)
+  # Number of total chains is local batch size * distributed axis size.
+  local_axis_size = ps.maximum(ps.size(accept_prob), 1)
+  distributed_axis_size = int(ps.reduce_prod([
+      distribute_lib.get_axis_size(a) for a in experimental_chain_axis_names]))
+  num_chains = local_axis_size * distributed_axis_size
   num_chains_ = tf.get_static_value(num_chains)
   if num_chains_ is not None:
     if num_chains_ < 2:
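The chain count above generalizes `ps.size(accept_prob)` to the distributed setting: the total is the local batch size times the product of the named chain axis sizes. A toy check with made-up sizes:

```python
# Toy check of the chain-count arithmetic (all numbers are made up):
# 4 chains live on each device, and the chains are additionally spread
# across two named axes of sizes 2 and 3.
local_axis_size = 4
distributed_axis_sizes = [2, 3]

num_chains = local_axis_size
for axis_size in distributed_axis_sizes:
  num_chains *= axis_size

print(num_chains)  # 24 chains in total across all devices
```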
@@ -199,7 +219,9 @@ def _center_previous_state(x):
   def _center_previous_state(x):
     # The empirical mean here is a stand-in for the true mean, so we drop the
     # gradient that flows through this term.
-    return x - tf.stop_gradient(tf.reduce_mean(x, axis=batch_axes))
+    x_mean = _reduce_mean_with_axes(
+        x, batch_axes, experimental_chain_axis_names)
+    return x - tf.stop_gradient(x_mean)
 
   def _center_proposed_state(x):
     # The empirical mean here is a stand-in for the true mean, so we drop the
@@ -216,8 +238,10 @@ def _center_proposed_state(x):
     # If all accept_prob's are zero, the x_center will have a nonsense value,
     # but we'll discard the resultant gradients later on, so it's fine.
     x_center = (
-        tf.reduce_sum(expanded_accept_prob * x_safe, axis=batch_axes) /
-        (tf.reduce_sum(expanded_accept_prob, axis=batch_axes) + 1e-20))
+        _reduce_sum_with_axes(expanded_accept_prob * x_safe, batch_axes,
+                              experimental_chain_axis_names) /
+        (_reduce_sum_with_axes(expanded_accept_prob, batch_axes,
+                               experimental_chain_axis_names) + 1e-20))
 
     return x - tf.stop_gradient(x_center)
 
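The centered proposed state uses an acceptance-probability-weighted mean over all chains (local batch axes plus any named chain axes). A toy NumPy check of the local part, with made-up numbers:

```python
import numpy as np

# Chains with zero acceptance probability contribute nothing to the center,
# and the 1e-20 guard keeps the division finite even if all weights are zero.
accept_prob = np.array([0.9, 0.0, 0.3])
x = np.array([1.0, 100.0, 2.0])
x_center = np.sum(accept_prob * x) / (np.sum(accept_prob) + 1e-20)
print(x_center)  # 1.25; the x=100 chain is ignored since its weight is 0
```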
@@ -358,6 +382,7 @@ def __init__(
       proposed_state_getter_fn=hmc_like_proposed_state_getter_fn,
       validate_args=False,
       experimental_shard_axis_names=None,
+      experimental_chain_axis_names=None,
       name=None):
     """Creates the trajectory length adaptation kernel.
 
@@ -414,6 +439,8 @@ def __init__(
         outputs.
       experimental_shard_axis_names: A structure of string names indicating how
         members of the state are sharded.
+      experimental_chain_axis_names: A string or list of string names indicating
+        how batches of chains are sharded.
       name: Python `str` name prefixed to Ops created by this class. Default:
         'simple_step_size_adaptation'.
 
@@ -452,6 +479,7 @@ class docstring).
         proposed_state_getter_fn=hmc_like_proposed_state_getter_fn,
         validate_args=validate_args,
         experimental_shard_axis_names=experimental_shard_axis_names,
+        experimental_chain_axis_names=experimental_chain_axis_names,
         name=name,
     )
 
@@ -468,12 +496,15 @@ def num_adaptation_steps(self):
     return self._parameters['num_adaptation_steps']
 
   def criterion_fn(self, previous_state, proposed_state, accept_prob):
-    if self.experimental_shard_axis_names is None:
-      return self._parameters['criterion_fn'](previous_state, proposed_state,
-                                              accept_prob)
-    return self._parameters['criterion_fn'](
-        previous_state, proposed_state, accept_prob,
-        experimental_shard_axis_names=self.experimental_shard_axis_names)
+    kwargs = {}
+    if self.experimental_chain_axis_names is not None:
+      kwargs['experimental_chain_axis_names'] = (
+          self.experimental_chain_axis_names)
+    if self.experimental_shard_axis_names is not None:
+      kwargs['experimental_shard_axis_names'] = (
+          self.experimental_shard_axis_names)
+    return self._parameters['criterion_fn'](previous_state, proposed_state,
+                                            accept_prob, **kwargs)
 
   @property
   def max_leapfrog_steps(self):
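The kwargs indirection above keeps `criterion_fn` backward compatible: a user-supplied criterion that accepts neither experimental argument still works, because nothing extra is passed unless the corresponding parameter was set. A hypothetical sketch of the same dispatch outside the class:

```python
# `legacy_criterion` and the stand-in parameter values are illustrative,
# not part of the library.
def legacy_criterion(previous_state, proposed_state, accept_prob):
  return accept_prob  # knows nothing about the experimental kwargs

chain_axis_names = None  # stand-in for self.experimental_chain_axis_names
shard_axis_names = None  # stand-in for self.experimental_shard_axis_names

kwargs = {}
if chain_axis_names is not None:
  kwargs['experimental_chain_axis_names'] = chain_axis_names
if shard_axis_names is not None:
  kwargs['experimental_shard_axis_names'] = shard_axis_names

# With both names unset, kwargs is empty and the legacy signature is fine.
print(legacy_criterion(0., 1., 0.5, **kwargs))  # 0.5
```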
@@ -567,7 +598,8 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
         step_size=step_size,
         criterion_fn=self.criterion_fn,
         max_leapfrog_steps=self.max_leapfrog_steps,
-        experimental_shard_axis_names=self.experimental_shard_axis_names)
+        experimental_shard_axis_names=self.experimental_shard_axis_names,
+        experimental_chain_axis_names=self.experimental_chain_axis_names)
 
     # Undo the effect of adaptation if we're not in the burnin phase. We keep
     # the criterion, however, as that's a diagnostic. We also keep the
@@ -623,9 +655,16 @@ def is_calibrated(self):
   def experimental_shard_axis_names(self):
     return self._parameters['experimental_shard_axis_names']
 
+  @property
+  def experimental_chain_axis_names(self):
+    return self._parameters['experimental_chain_axis_names']
+
   def experimental_with_shard_axes(self, shard_axis_names):
     return self.copy(experimental_shard_axis_names=shard_axis_names)
 
+  def experimental_with_chain_axes(self, chain_axis_names):
+    return self.copy(experimental_chain_axis_names=chain_axis_names)
+
 
 def _forbid_inner_transformed_kernel(inner_kernel):
   """Forbids inner kernel from containing `TransformedTransitionKernel`."""
@@ -669,7 +708,8 @@ def _update_trajectory_grad(previous_kernel_results, previous_state,
                             proposed_state, proposed_velocity,
                             trajectory_jitter, accept_prob, step_size,
                             criterion_fn, max_leapfrog_steps,
-                            experimental_shard_axis_names=None):
+                            experimental_shard_axis_names=None,
+                            experimental_chain_axis_names=None):
   """Updates the trajectory length."""
   # Compute criterion grads.
   def leapfrog_action(dt):
@@ -693,12 +733,16 @@ def adjust_state(x, v, shard_axes=None):
   trajectory_grad *= trajectory_jitter
 
   # Weight by acceptance probability.
+  experimental_chain_axis_names = distribute_lib.canonicalize_axis_name(
+      experimental_chain_axis_names)
   trajectory_grad = tf.where(accept_prob > 1e-4, trajectory_grad, 0.)
   trajectory_grad = tf.where(
       tf.math.is_finite(trajectory_grad), trajectory_grad, 0.)
   trajectory_grad = (
-      tf.reduce_sum(trajectory_grad * accept_prob) /
-      tf.reduce_sum(accept_prob + 1e-20))
+      _reduce_sum_with_axes(trajectory_grad * accept_prob,
+                            None, experimental_chain_axis_names) /
+      _reduce_sum_with_axes(accept_prob + 1e-20, None,
+                            experimental_chain_axis_names))
 
   # Compute Adam/RMSProp step size.
   dtype = previous_kernel_results.adaptation_rate.dtype
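Here `None` for the positional axes makes `tf.reduce_sum` collapse every local axis before the named-axis sum, so the trajectory gradient becomes a single scalar averaged over all chains on all devices. A single-device toy check with made-up numbers:

```python
import tensorflow.compat.v2 as tf

g = tf.constant([[0.1, 0.2], [0.3, 0.4]])  # per-chain trajectory gradients
a = tf.constant([[1.0, 0.5], [0.0, 1.0]])  # per-chain acceptance probabilities
weighted = tf.reduce_sum(g * a, None) / tf.reduce_sum(a + 1e-20, None)
print(float(weighted))  # 0.24
```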