
Commit bc6c411

langmore authored and tensorflower-gardener committed
Add a perturbed_observations option to ensemble_kalman_filter_log_marginal_likelihood.
If False, the observation covariance is computed in a less-stochastic manner that guarantees an SPD result, even with small ensemble sizes. The name "perturbed observations" is chosen because this corresponds to the (well-known) "perturbed observation" *update* step. There is no well-known name for this technique as applied to marginal likelihood (as I've done here), but borrowing the same name seems appropriate.

PiperOrigin-RevId: 451771148
1 parent ac62542 commit bc6c411
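
In the notation used by the code comments in the diff below (G is the observation operator applied to the state ensemble X, and η ~ Normal(0, Γ) is the observation noise), the two covariance estimates are roughly:

    perturbed_observations=True:   Cov(G(X) + η)    (sample covariance of the sampled observation particles)
    perturbed_observations=False:  Cov(G(X)) + Γ    (sample covariance of the predictions, with Γ added exactly)

A sample covariance of n_ensemble vectors has rank at most n_ensemble - 1, so the first estimate cannot be SPD unless the ensemble is at least one member larger than the number of observations; the second is SPD whenever Γ is.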

2 files changed (+89 lines, -12 lines)

tensorflow_probability/python/experimental/sequential/ensemble_kalman_filter.py

Lines changed: 45 additions & 10 deletions
@@ -29,6 +29,10 @@
 ]
 
 
+class InsufficientEnsembleSizeError(Exception):
+  """Raise when the ensemble size is insufficient for a function."""
+
+
 # Sample covariance. Handles differing shapes.
 def _covariance(x, y=None):
   """Sample covariance, assuming samples are the leftmost axis."""
@@ -304,6 +308,7 @@ def ensemble_kalman_filter_log_marginal_likelihood(
     state,
     observation,
     observation_fn,
+    perturbed_observations=True,
     seed=None,
     name=None):
   """Ensemble Kalman filter log marginal likelihood.
@@ -332,6 +337,11 @@ def ensemble_kalman_filter_log_marginal_likelihood(
     observation_fn: callable returning an instance of
       `tfd.MultivariateNormalLinearOperator` along with an extra information
       to be returned in the `EnsembleKalmanFilterState`.
+    perturbed_observations: Whether the marginal distribution `p(Y[t] | ...)`
+      is estimated using samples from the `observation_fn`'s distribution. If
+      `False`, the distribution's covariance matrix is used directly. This
+      latter choice is less common in the literature, but works even if the
+      ensemble size is smaller than the number of observations.
     seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
     name: Python `str` name for ops created by this method.
       Default value: `None`
@@ -340,6 +350,10 @@ def ensemble_kalman_filter_log_marginal_likelihood(
   Returns:
     log_marginal_likelihood: `Tensor` with same dtype as `state`.
 
+  Raises:
+    InsufficientEnsembleSizeError: If `perturbed_observations=True` and the
+      ensemble size is not at least one greater than the number of observations.
+
   #### References
 
   [1] Geir Evensen. Sequential data assimilation with a nonlinear
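
As an illustrative usage sketch of the new argument (not part of this commit; the toy observation_fn, shapes, and integer seed below are assumptions that mirror the test added later in this diff):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions
    tfs = tfp.experimental.sequential

    event_size = 5
    # An ensemble of only 3 members, smaller than the observation dimension.
    particles = {'x': tf.random.normal([3, event_size])}
    state = tfs.EnsembleKalmanFilterState(step=0, particles=particles, extra={})

    def observation_fn(_, particles, extra):
      return tfd.MultivariateNormalDiag(
          loc=particles['x'], scale_diag=[1e-2] * event_size), extra

    # perturbed_observations=False keeps the estimated observation covariance SPD
    # even though the ensemble is smaller than the number of observations.
    log_ml = tfs.ensemble_kalman_filter_log_marginal_likelihood(
        state=state,
        observation=tf.random.normal([event_size]),
        observation_fn=observation_fn,
        perturbed_observations=False,
        seed=42)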
@@ -360,16 +374,37 @@ def ensemble_kalman_filter_log_marginal_likelihood(
 
     observation = tf.convert_to_tensor(observation, dtype=common_dtype)
 
-    if not isinstance(observation_particles_dist,
-                      distributions.MultivariateNormalLinearOperator):
-      raise ValueError('Expected `observation_fn` to return an instance of '
-                       '`MultivariateNormalLinearOperator`')
-
-    observation_particles = observation_particles_dist.sample(seed=seed)
-    observation_dist = distributions.MultivariateNormalTriL(
-        loc=tf.reduce_mean(observation_particles, axis=0),
-        scale_tril=tf.linalg.cholesky(_covariance(observation_particles)))
-
+    if perturbed_observations:
+      # With G the observation operator and B the batch shape,
+      # observation_particles = G(X) + η, where η ~ Normal(0, Γ).
+      # Both are shape [n_ensemble] + B + [n_observations]
+      observation_particles = observation_particles_dist.sample(seed=seed)
+      n_observations = observation_particles_dist.event_shape[0]
+      n_ensemble = observation_particles_dist.batch_shape[0]
+      if (n_ensemble is not None and n_observations is not None and
+          n_ensemble < n_observations + 1):
+        raise InsufficientEnsembleSizeError(
+            f'When `perturbed_observations=True`, ensemble size ({n_ensemble}) '
+            'must be at least one greater than the number of observations '
+            f'({n_observations}), but it was not.')
+      observation_dist = distributions.MultivariateNormalTriL(
+          loc=tf.reduce_mean(observation_particles, axis=0),
+          # Cholesky(Cov(G(X) + η)), where Cov(..) is the ensemble covariance.
+          scale_tril=tf.linalg.cholesky(_covariance(observation_particles)))
+    else:
+      # predicted_observation = G(X),
+      # and is shape [n_ensemble] + B.
+      predicted_observation = observation_particles_dist.mean()
+      observation_dist = distributions.MultivariateNormalTriL(
+          loc=tf.reduce_mean(predicted_observation, axis=0),  # ensemble mean
+          # Cholesky(Cov(G(X)) + Γ), where Cov(..) is the ensemble covariance.
+          scale_tril=tf.linalg.cholesky(
+              _covariance(predicted_observation) +
+              _linop_covariance(observation_particles_dist).to_dense()))
+
+    # Above we computed observation_dist, the distribution of observations given
+    # the predictive distribution of states (e.g. states from previous time).
+    # Here we evaluate the log_prob on the actual observations.
     return observation_dist.log_prob(observation)
 
 
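To see concretely why the `else` branch avoids the failure mode that `InsufficientEnsembleSizeError` guards against, here is a small NumPy sketch (an editor's illustration with assumed shapes, not code from the commit):

    import numpy as np

    rng = np.random.default_rng(0)
    n_ensemble, n_observations = 3, 5   # ensemble smaller than the observation dimension
    g_of_x = rng.normal(size=(n_ensemble, n_observations))  # plays the role of G(X)
    gamma = 1e-4 * np.eye(n_observations)                   # plays the role of Γ

    def sample_cov(z):
      # Sample covariance with samples on the leftmost axis.
      zc = z - z.mean(axis=0, keepdims=True)
      return zc.T @ zc / (z.shape[0] - 1)

    eta = rng.multivariate_normal(np.zeros(n_observations), gamma, size=n_ensemble)
    cov_perturbed = sample_cov(g_of_x + eta)  # rank <= n_ensemble - 1 = 2, hence singular
    cov_direct = sample_cov(g_of_x) + gamma   # SPD, because gamma is SPD

    print(np.linalg.matrix_rank(cov_perturbed))  # 2
    np.linalg.cholesky(cov_direct)               # succeeds
    # np.linalg.cholesky(cov_perturbed) would typically raise LinAlgError.
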
tensorflow_probability/python/experimental/sequential/ensemble_kalman_filter_test.py

Lines changed: 44 additions & 2 deletions
@@ -269,6 +269,7 @@ def observation_fn(_, particles, extra):
       self.assertAllEqual(particles_shape[1:-1], log_ml.shape)
       self.assertIn('observation_count', state.extra)
       self.assertEqual(3 * i + 1, state.extra['observation_count'])
+      self.assertFalse(np.any(np.isnan(self.evaluate(log_ml))))
 
     log_ml_krazy_obs = tfs.ensemble_kalman_filter_log_marginal_likelihood(
         state,
@@ -293,6 +294,38 @@ def observation_fn(_, particles, extra):
         self.evaluate(tf.reduce_mean(state.particles['x'], axis=0)),
         rtol=0.05)
 
+  def test_log_marginal_likelihood_with_small_ensemble_no_perturb_obs(self):
+    # With perturbed_observations=False, we should be able to handle the small
+    # ensemble without NaN.
+
+    # Initialize an ensemble that is smaller than the event size.
+    seed_stream = test_util.test_seed_stream()
+    n_ensemble = 3
+    event_size = 5
+    self.assertLess(n_ensemble, event_size)
+    particles_shape = (n_ensemble, event_size)
+
+    particles = {
+        'x':
+            self.evaluate(
+                tf.random.normal(shape=particles_shape, seed=seed_stream())),
+    }
+
+    def observation_fn(_, particles, extra):
+      return tfd.MultivariateNormalDiag(
+          loc=particles['x'], scale_diag=[1e-2] * event_size), extra
+
+    # Marginal likelihood.
+    log_ml = tfs.ensemble_kalman_filter_log_marginal_likelihood(
+        state=tfs.EnsembleKalmanFilterState(
+            step=0, particles=particles, extra={}),
+        observation=tf.random.normal(shape=(event_size,), seed=seed_stream()),
+        observation_fn=observation_fn,
+        perturbed_observations=False,
+        seed=test_util.test_seed())
+    self.assertAllEqual(particles_shape[1:-1], log_ml.shape)
+    self.assertFalse(np.any(np.isnan(self.evaluate(log_ml))))
+
 
 # Parameters defining a linear/Gaussian state space model.
 LinearModelParams = collections.namedtuple('LinearModelParams', [
@@ -484,8 +517,15 @@ def _enkf_solve(self, observation, enkf_params, predict_kwargs, update_kwargs,
           noise_level=[0.001, 0.1, 1.0],
           n_states=[2, 5],
           n_observations=[2, 5],
+          perturbed_observations=[False, True],
       ))
-  def test_same_solution(self, noise_level, n_states, n_observations):
+  def test_same_solution(
+      self,
+      noise_level,
+      n_states,
+      n_observations,
+      perturbed_observations,
+  ):
     """Check that the KF and EnKF solutions are the same."""
     # Tests pass with n_ensemble = 1e7. The KF vs. EnKF tolerance is
     # proportional to 1 / sqrt(n_ensemble), so this shows good agreement.
@@ -496,7 +536,9 @@ def test_same_solution(self, noise_level, n_states, n_observations):
     dtype = tf.float64
     predict_kwargs = {}
     update_kwargs = {}
-    log_marginal_likelihood_kwargs = {}
+    log_marginal_likelihood_kwargs = {
+        'perturbed_observations': perturbed_observations
+    }
 
     linear_model_params = self._get_linear_model_params(
         noise_level=noise_level,
