Skip to content

Commit bc93267

Browse files
davmretensorflower-gardener
authored andcommitted
Fix bug in particle filter where static evaluation could block the gradient signal.
PiperOrigin-RevId: 384310858
1 parent 4de9d3c commit bc93267

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

tensorflow_probability/python/experimental/mcmc/particle_filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ def _compute_observation_log_weights(step,
500500
lambda x, step=step: tf.gather(x, observation_idx), observations)
501501

502502
log_weights = observation_fn(step, particles).log_prob(observation)
503-
return ps.where(step_has_observation,
503+
return tf.where(step_has_observation,
504504
log_weights,
505505
tf.zeros_like(log_weights))
506506

tensorflow_probability/python/experimental/mcmc/particle_filter_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,23 @@ def transition_fn_no_batch_shape(_, particles):
561561
run_filter(transition_fn=valid_transition_fn,
562562
proposal_fn=transition_fn_no_batch_shape))
563563

564+
@test_util.jax_disable_test_missing_functionality('Gradient of while_loop.')
565+
def test_marginal_likelihood_gradients_are_defined(self):
566+
567+
def marginal_log_likelihood(level_scale, noise_scale):
568+
_, _, _, lps = tfp.experimental.mcmc.particle_filter(
569+
observations=tf.convert_to_tensor([1., 2., 3., 4., 5.]),
570+
initial_state_prior=tfd.Normal(loc=0, scale=1.),
571+
transition_fn=lambda _, x: tfd.Normal(loc=x, scale=level_scale),
572+
observation_fn=lambda _, x: tfd.Normal(loc=x, scale=noise_scale),
573+
num_particles=4,
574+
seed=test_util.test_seed())
575+
return tf.reduce_sum(lps)
576+
577+
_, grads = tfp.math.value_and_gradient(marginal_log_likelihood, 1.0, 1.0)
578+
self.assertAllNotNone(grads)
579+
self.assertNotAllZero(grads)
580+
564581

565582
# TODO(b/186068104): add tests with dynamic shapes.
566583
class ParticleFilterTestFloat32(_ParticleFilterTest):

tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def one_step(self, state, kernel_results, seed=None):
277277
(resampled_particles,
278278
resample_indices,
279279
log_weights) = tf.nest.map_structure(
280-
lambda r, p: ps.where(do_resample, r, p),
280+
lambda r, p: tf.where(do_resample, r, p),
281281
(resampled_particles, resample_indices, uniform_weights),
282282
(state.particles, _dummy_indices_like(resample_indices),
283283
normalized_log_weights))

0 commit comments

Comments
 (0)