Skip to content

Commit 36d8854

Browse files
davmretensorflower-gardener
authored andcommitted
Support unbiased gradients through particle filtering via stop-gradient resampling.
This implements the approach proposed by Adam Scibior, Vaden Masrani, and Frank Wood in "Differentiable Particle Filtering without Modifying the Forward Pass" (2021, https://arxiv.org/abs/2106.10314). PiperOrigin-RevId: 384336880
1 parent 2eafe29 commit 36d8854

File tree

5 files changed

+113
-19
lines changed

5 files changed

+113
-19
lines changed

tensorflow_probability/python/experimental/mcmc/particle_filter.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ def _default_trace_fn(state, kernel_results):
9494
The default behavior resamples particles when the current effective
9595
sample size falls below half the total number of particles.
9696
Default value: `tfp.experimental.mcmc.ess_below_threshold`.
97+
unbiased_gradients: If `True`, use the stop-gradient
98+
resampling trick of Scibior, Masrani, and Wood [{scibor_ref_idx}] to
99+
correct for gradient bias introduced by the discrete resampling step. This
100+
will generally increase the variance of stochastic gradients.
101+
Default value: `True`.
97102
rejuvenation_kernel_fn: optional Python `callable` with signature
98103
`transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn)`
99104
where `target_log_prob_fn` is a provided callable evaluating
@@ -112,7 +117,7 @@ def _default_trace_fn(state, kernel_results):
112117

113118

114119
@docstring_util.expand_docstring(
115-
particle_filter_arg_str=particle_filter_arg_str)
120+
particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=2))
116121
def infer_trajectories(observations,
117122
initial_state_prior,
118123
transition_fn,
@@ -122,6 +127,7 @@ def infer_trajectories(observations,
122127
proposal_fn=None,
123128
resample_fn=weighted_resampling.resample_systematic,
124129
resample_criterion_fn=smc_kernel.ess_below_threshold,
130+
unbiased_gradients=True,
125131
rejuvenation_kernel_fn=None,
126132
num_transitions_per_observation=1,
127133
seed=None,
@@ -224,6 +230,9 @@ def observation_fn(_, state):
224230
filtering and smoothing: Fifteen years later.
225231
_Handbook of nonlinear filtering_, 12(656-704), 2009.
226232
https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf
233+
[2] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle
234+
Filtering without Modifying the Forward Pass. _arXiv preprint
235+
arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
227236
228237
"""
229238
with tf.name_scope(name or 'infer_trajectories') as name:
@@ -242,6 +251,7 @@ def observation_fn(_, state):
242251
proposal_fn=proposal_fn,
243252
resample_fn=resample_fn,
244253
resample_criterion_fn=resample_criterion_fn,
254+
unbiased_gradients=unbiased_gradients,
245255
rejuvenation_kernel_fn=rejuvenation_kernel_fn,
246256
num_transitions_per_observation=num_transitions_per_observation,
247257
trace_fn=_default_trace_fn,
@@ -265,7 +275,7 @@ def observation_fn(_, state):
265275

266276

267277
@docstring_util.expand_docstring(
268-
particle_filter_arg_str=particle_filter_arg_str)
278+
particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=1))
269279
def particle_filter(observations,
270280
initial_state_prior,
271281
transition_fn,
@@ -275,6 +285,7 @@ def particle_filter(observations,
275285
proposal_fn=None,
276286
resample_fn=weighted_resampling.resample_systematic,
277287
resample_criterion_fn=smc_kernel.ess_below_threshold,
288+
unbiased_gradients=True,
278289
rejuvenation_kernel_fn=None, # TODO(davmre): not yet supported. pylint: disable=unused-argument
279290
num_transitions_per_observation=1,
280291
trace_fn=_default_trace_fn,
@@ -324,6 +335,12 @@ def particle_filter(observations,
324335
`trace_criterion_fn==None`, this is computed from the final step;
325336
otherwise, each Tensor will have initial dimension `num_steps_traced`
326337
and stacks the traced results across all steps.
338+
339+
#### References
340+
341+
[1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle
342+
Filtering without Modifying the Forward Pass. _arXiv preprint
343+
arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
327344
"""
328345

329346
init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter')
@@ -356,7 +373,8 @@ def particle_filter(observations,
356373
kernel = smc_kernel.SequentialMonteCarlo(
357374
propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
358375
resample_fn=resample_fn,
359-
resample_criterion_fn=resample_criterion_fn)
376+
resample_criterion_fn=resample_criterion_fn,
377+
unbiased_gradients=unbiased_gradients)
360378

361379
# Use `trace_scan` rather than `sample_chain` directly because the latter
362380
# would force us to trace the state history (with or without thinning),

tensorflow_probability/python/experimental/mcmc/sample_sequential_monte_carlo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def smc_body_fn(stage, state, smc_kernel_result):
541541
smc_kernel_result.particle_info.log_scalings, axis=0))
542542
)
543543
(resampled_state,
544-
resampled_particle_info), _ = weighted_resampling.resample(
544+
resampled_particle_info), _, _ = weighted_resampling.resample(
545545
particles=(state, smc_kernel_result.particle_info),
546546
log_weights=log_weights,
547547
resample_fn=resample_fn,

tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(self,
140140
propose_and_update_log_weights_fn,
141141
resample_fn=weighted_resampling.resample_systematic,
142142
resample_criterion_fn=ess_below_threshold,
143+
unbiased_gradients=True,
143144
name=None):
144145
"""Initializes a sequential Monte Carlo transition kernel.
145146
@@ -180,11 +181,23 @@ def __init__(self,
180181
default behavior is to resample particles when the effective
181182
sample size falls below half of the total number of particles.
182183
Default value: `tfp.experimental.mcmc.ess_below_threshold`.
184+
unbiased_gradients: If `True`, use the stop-gradient
185+
resampling trick of Scibior, Masrani, and Wood [{scibor_ref_idx}] to
186+
correct for gradient bias introduced by the discrete resampling step.
187+
This will generally increase the variance of stochastic gradients.
188+
Default value: `True`.
183189
name: Python `str` name for ops created by this kernel.
190+
191+
#### References
192+
193+
[1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle
194+
Filtering without Modifying the Forward Pass. _arXiv preprint
195+
arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
184196
"""
185197
self._propose_and_update_log_weights_fn = propose_and_update_log_weights_fn
186198
self._resample_fn = resample_fn
187199
self._resample_criterion_fn = resample_criterion_fn
200+
self._unbiased_gradients = unbiased_gradients
188201
self._name = name or 'SequentialMonteCarlo'
189202

190203
@property
@@ -203,6 +216,10 @@ def propose_and_update_log_weights_fn(self):
203216
def resample_criterion_fn(self):
204217
return self._resample_criterion_fn
205218

219+
@property
220+
def unbiased_gradients(self):
221+
return self._unbiased_gradients
222+
206223
@property
207224
def resample_fn(self):
208225
return self._resample_fn
@@ -234,7 +251,6 @@ def one_step(self, state, kernel_results, seed=None):
234251
proposal_seed, resample_seed = samplers.split_seed(seed)
235252

236253
state = WeightedParticles(*state) # Canonicalize.
237-
num_particles = ps.size0(state.log_weights)
238254

239255
# Propose new particles and update weights for this step, unless it's
240256
# the initial step, in which case, use the user-provided initial
@@ -266,19 +282,26 @@ def one_step(self, state, kernel_results, seed=None):
266282
# needed---but we're ultimately interested in adaptive resampling
267283
# for statistical (not computational) purposes, so this isn't a
268284
# dealbreaker.
269-
resampled_particles, resample_indices = weighted_resampling.resample(
270-
state.particles,
271-
state.log_weights,
272-
self.resample_fn,
285+
[
286+
resampled_particles,
287+
resample_indices,
288+
weights_after_resampling
289+
] = weighted_resampling.resample(
290+
particles=state.particles,
291+
# The `stop_gradient` here does not affect discrete resampling
292+
# (which is nondifferentiable anyway), but avoids canceling out
293+
# the gradient signal from the 'target' log weights, as described in
294+
# Scibior, Masrani, and Wood (2021).
295+
log_weights=tf.stop_gradient(state.log_weights),
296+
resample_fn=self.resample_fn,
297+
target_log_weights=(normalized_log_weights
298+
if self.unbiased_gradients else None),
273299
seed=resample_seed)
274-
uniform_weights = tf.fill(
275-
ps.shape(state.log_weights),
276-
value=-tf.math.log(tf.cast(num_particles, state.log_weights.dtype)))
277300
(resampled_particles,
278301
resample_indices,
279302
log_weights) = tf.nest.map_structure(
280303
lambda r, p: tf.where(do_resample, r, p),
281-
(resampled_particles, resample_indices, uniform_weights),
304+
(resampled_particles, resample_indices, weights_after_resampling),
282305
(state.particles, _dummy_indices_like(resample_indices),
283306
normalized_log_weights))
284307

tensorflow_probability/python/experimental/mcmc/weighted_resampling.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
]
3535

3636

37-
def resample(particles, log_weights, resample_fn, seed=None):
37+
def resample(particles, log_weights, resample_fn, target_log_weights=None,
38+
seed=None):
3839
"""Resamples the current particles according to provided weights.
3940
4041
Args:
@@ -47,21 +48,46 @@ def resample(particles, log_weights, resample_fn, seed=None):
4748
Use 'resample_independent' for independent resamples.
4849
Use 'resample_stratified' for stratified resampling.
4950
Use 'resample_systematic' for systematic resampling.
51+
target_log_weights: optional float `Tensor` of the same shape and dtype as
52+
`log_weights`, specifying the target measure on `particles` if this is
53+
different from that implied by normalizing `log_weights`. The
54+
returned `log_weights_after_resampling` will represent this measure. If
55+
`None`, the target measure is implicitly taken to be the normalized
56+
log weights (`log_weights - tf.reduce_logsumexp(log_weights, axis=0)`).
57+
Default value: `None`.
5058
seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
5159
5260
Returns:
5361
resampled_particles: Nested structure of `Tensor`s, matching `particles`.
5462
resample_indices: int `Tensor` of shape `[num_particles, b1, ..., bN]`.
63+
log_weights_after_resampling: float `Tensor` of same shape and dtype as
64+
`log_weights`, such that weighted sums of the resampled particles are
65+
equal (in expectation over the resampling step) to weighted sums of
66+
the original particles:
67+
`E [ exp(log_weights_after_resampling) * some_fn(resampled_particles) ]
68+
== exp(target_log_weights) * some_fn(particles)`.
69+
If no `target_log_weights` was specified, the log weights after
70+
resampling are uniformly equal to `-log(num_particles)`.
5571
"""
5672
with tf.name_scope('resample'):
5773
num_particles = ps.size0(log_weights)
74+
log_num_particles = tf.math.log(tf.cast(num_particles, log_weights.dtype))
75+
76+
# Normalize the weights and sample the ancestral indices.
5877
log_probs = tf.math.log_softmax(log_weights, axis=0)
5978
resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed)
60-
resampled_particles = tf.nest.map_structure(
61-
lambda x: mcmc_util.index_remapping_gather( # pylint: disable=g-long-lambda
62-
x, resampled_indices, axis=0),
63-
particles)
64-
return resampled_particles, resampled_indices
79+
80+
gather_ancestors = lambda x: ( # pylint: disable=g-long-lambda
81+
mcmc_util.index_remapping_gather(x, resampled_indices, axis=0))
82+
resampled_particles = tf.nest.map_structure(gather_ancestors, particles)
83+
if target_log_weights is None:
84+
log_weights_after_resampling = tf.fill(ps.shape(log_weights),
85+
-log_num_particles)
86+
else:
87+
importance_weights = target_log_weights - log_probs - log_num_particles
88+
log_weights_after_resampling = tf.nest.map_structure(
89+
gather_ancestors, importance_weights)
90+
return resampled_particles, resampled_indices, log_weights_after_resampling
6591

6692

6793
# TODO(b/153689734): rewrite so as not to use `move_dimension`.

tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import tensorflow_probability as tfp
2424
from tensorflow_probability.python.experimental.mcmc.weighted_resampling import _resample_using_log_points
2525
from tensorflow_probability.python.experimental.mcmc.weighted_resampling import _scatter_nd_batch
26+
from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample
2627
from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_deterministic_minimum_error
2728
from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_independent
2829
from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_stratified
@@ -275,6 +276,32 @@ def test_resample_using_extremal_log_points(self):
275276
log_probs_end, sample_shape, log_points_almost_one)
276277
self.assertAllEqual(indices, tf.fill([n], n - 1))
277278

279+
def resample_with_target_distribution(self):
280+
particles = np.linspace(0., 500., num=2500, dtype=np.float32)
281+
log_weights = tfd.Poisson(20.).log_prob(particles)
282+
283+
# Resample particles to target a Poisson(20.) distribution.
284+
new_particles, _, new_log_weights = resample(
285+
particles, log_weights,
286+
resample_fn=resample_systematic,
287+
seed=test_util.test_seed(sampler_type='stateless'))
288+
self.assertAllClose(tf.reduce_mean(new_particles), 20., atol=1e-2)
289+
self.assertAllClose(
290+
tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles),
291+
20.,
292+
atol=1e-2)
293+
294+
# Reweight the resampled particles to target a Poisson(30.) distribution.
295+
new_particles, _, new_log_weights = resample(
296+
particles, log_weights,
297+
resample_fn=resample_systematic,
298+
target_log_weights=tfd.Poisson(30).log_prob(particles),
299+
seed=test_util.test_seed(sampler_type='stateless'))
300+
self.assertAllClose(tf.reduce_mean(new_particles), 20., atol=1e-2)
301+
self.assertAllClose(
302+
tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles),
303+
30., atol=1.)
304+
278305
def maybe_compiler(self, f):
279306
if self.use_xla:
280307
return tf.function(f, autograph=False, jit_compile=True)

0 commit comments

Comments
 (0)