 ]


+class InsufficientEnsembleSizeError(Exception):
+  """Raised when the ensemble size is insufficient for a function."""
+
+
 # Sample covariance. Handles differing shapes.
 def _covariance(x, y=None):
   """Sample covariance, assuming samples are the leftmost axis."""
@@ -304,6 +308,7 @@ def ensemble_kalman_filter_log_marginal_likelihood(
     state,
     observation,
     observation_fn,
+    perturbed_observations=True,
     seed=None,
     name=None):
   """Ensemble Kalman filter log marginal likelihood.
@@ -332,6 +337,11 @@ def ensemble_kalman_filter_log_marginal_likelihood(
     observation_fn: callable returning an instance of
       `tfd.MultivariateNormalLinearOperator` along with extra information
       to be returned in the `EnsembleKalmanFilterState`.
+    perturbed_observations: Whether the marginal distribution `p(Y[t] | ...)`
+      is estimated using samples from the `observation_fn`'s distribution. If
+      `False`, the distribution's covariance matrix is used directly. This
+      latter choice is less common in the literature, but works even if the
+      ensemble size is smaller than the number of observations.
     seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
     name: Python `str` name for ops created by this method.
       Default value: `None`
@@ -340,6 +350,10 @@ def ensemble_kalman_filter_log_marginal_likelihood(
   Returns:
     log_marginal_likelihood: `Tensor` with same dtype as `state`.

+  Raises:
+    InsufficientEnsembleSizeError: If `perturbed_observations=True` and the
+      ensemble size is not at least one greater than the number of
+      observations.
+
   #### References

   [1] Geir Evensen. Sequential data assimilation with a nonlinear
@@ -360,16 +374,37 @@ def ensemble_kalman_filter_log_marginal_likelihood(

     observation = tf.convert_to_tensor(observation, dtype=common_dtype)

-    if not isinstance(observation_particles_dist,
-                      distributions.MultivariateNormalLinearOperator):
-      raise ValueError('Expected `observation_fn` to return an instance of '
-                       '`MultivariateNormalLinearOperator`')
-
-    observation_particles = observation_particles_dist.sample(seed=seed)
-    observation_dist = distributions.MultivariateNormalTriL(
-        loc=tf.reduce_mean(observation_particles, axis=0),
-        scale_tril=tf.linalg.cholesky(_covariance(observation_particles)))
-
+    if perturbed_observations:
+      # With G the observation operator and B the batch shape,
+      # observation_particles = G(X) + η, where η ~ Normal(0, Γ).
+      # Both are shape [n_ensemble] + B + [n_observations].
+      observation_particles = observation_particles_dist.sample(seed=seed)
+      n_observations = observation_particles_dist.event_shape[0]
+      n_ensemble = observation_particles_dist.batch_shape[0]
+      if (n_ensemble is not None and n_observations is not None and
+          n_ensemble < n_observations + 1):
+        raise InsufficientEnsembleSizeError(
+            f'When `perturbed_observations=True`, the ensemble size '
+            f'({n_ensemble}) must be at least one greater than the number '
+            f'of observations ({n_observations}).')
+      observation_dist = distributions.MultivariateNormalTriL(
+          loc=tf.reduce_mean(observation_particles, axis=0),
+          # Cholesky(Cov(G(X) + η)), where Cov(.) is the ensemble covariance.
+          scale_tril=tf.linalg.cholesky(_covariance(observation_particles)))
+    else:
+      # predicted_observation = G(X),
+      # of shape [n_ensemble] + B + [n_observations].
+      predicted_observation = observation_particles_dist.mean()
+      observation_dist = distributions.MultivariateNormalTriL(
+          loc=tf.reduce_mean(predicted_observation, axis=0),  # Ensemble mean.
+          # Cholesky(Cov(G(X)) + Γ), where Cov(.) is the ensemble covariance.
+          scale_tril=tf.linalg.cholesky(
+              _covariance(predicted_observation) +
+              _linop_covariance(observation_particles_dist).to_dense()))
+
+    # Above we computed observation_dist, the distribution of observations
+    # given the predictive distribution of states (e.g. states from the
+    # previous time step). Here we evaluate its log_prob on the actual
+    # observations.
     return observation_dist.log_prob(observation)
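
To make the two branches above concrete, here is a minimal NumPy sketch of the two covariance estimates for p(Y[t] | ...). It is illustrative only: the names X, G, Gamma, and ensemble_cov are hypothetical stand-ins, not objects from this module.

import numpy as np

rng = np.random.default_rng(0)
n_ensemble, n_states, n_observations = 10, 4, 3

X = rng.normal(size=(n_ensemble, n_states))       # State ensemble.
G = rng.normal(size=(n_observations, n_states))   # Linear observation operator.
Gamma = 0.5 * np.eye(n_observations)              # Observation noise covariance.

def ensemble_cov(samples):
  """Sample covariance, with samples on the leftmost axis."""
  centered = samples - samples.mean(axis=0)
  return centered.T @ centered / (samples.shape[0] - 1)

GX = X @ G.T  # G(X), shape [n_ensemble, n_observations].

# perturbed_observations=True: sample G(X) + eta with eta ~ Normal(0, Gamma),
# then take the ensemble covariance of the perturbed samples.
eta = rng.multivariate_normal(np.zeros(n_observations), Gamma, size=n_ensemble)
cov_perturbed = ensemble_cov(GX + eta)            # Cov(G(X) + eta).

# perturbed_observations=False: use the noise covariance Gamma directly.
cov_direct = ensemble_cov(GX) + Gamma             # Cov(G(X)) + Gamma.

# Both estimate the same marginal covariance. The ensemble covariance of N
# samples has rank at most N - 1, so the first estimate needs
# n_ensemble >= n_observations + 1 for its Cholesky factor to exist; the
# second stays positive definite because Gamma is.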
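
A hypothetical usage sketch of the new flag follows. It assumes this function is exposed as tfp.experimental.sequential.ensemble_kalman_filter_log_marginal_likelihood, that EnsembleKalmanFilterState(step, particles, extra) is the state container, and that observation_fn receives (step, particles, extra) and returns a (distribution, extra) pair as the docstring describes; those call conventions are assumptions, not verified against the module.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfes = tfp.experimental.sequential

n_ensemble, n_observations = 5, 20  # Fewer members than observed dimensions.
particles = tf.random.normal([n_ensemble, n_observations])
state = tfes.EnsembleKalmanFilterState(step=0, particles=particles, extra={})

def observation_fn(step, particles, extra):
  # Observe the state directly, with unit observation noise.
  dist = tfd.MultivariateNormalDiag(
      loc=particles, scale_diag=tf.ones([n_observations]))
  return dist, extra

observation = tf.zeros([n_observations])

# With perturbed_observations=True this would raise
# InsufficientEnsembleSizeError, since 5 < 20 + 1; the direct-covariance
# branch still works.
log_ml = tfes.ensemble_kalman_filter_log_marginal_likelihood(
    state, observation, observation_fn, perturbed_observations=False)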