21
21
import tensorflow .compat .v2 as tf
22
22
23
23
from tensorflow_probability .python .distributions import mvn_tril
24
+ from tensorflow_probability .python .internal import dtype_util
24
25
from tensorflow_probability .python .internal import prefer_static as ps
25
26
from tensorflow_probability .python .internal import samplers
26
27
from tensorflow_probability .python .math import linalg
@@ -625,11 +626,8 @@ def kalman_filter(transition_matrix,
625
626
axis = 0 ),
626
627
added_cov = time_dep .observation_cov )
627
628
628
- # TODO(srvasude): The JVP for this can be implemented more efficiently.
629
- log_likelihoods = mvn_tril .MultivariateNormalTriL (
630
- loc = observation_means ,
631
- scale_tril = tf .linalg .cholesky (observation_covs )).log_prob (
632
- observation .y )
629
+ log_likelihoods = _mvn_log_prob (
630
+ observation_means , observation_covs , observation .y )
633
631
if observation .mask is not None :
634
632
log_likelihoods = tf .where (observation .mask ,
635
633
tf .zeros ([], dtype = log_likelihoods .dtype ),
@@ -644,6 +642,17 @@ def kalman_filter(transition_matrix,
644
642
observation_covs )
645
643
646
644
645
def _mvn_log_prob(mean, covariance, y):
  """Log-density of `y` under a multivariate normal given by its covariance.

  Evaluates
    -0.5 * (y - mean)^T covariance^{-1} (y - mean)
    -0.5 * log|covariance|
    -0.5 * d * log(2 * pi)
  where `d` is the size of the trailing axis of `mean`. A single Cholesky
  factor of `covariance` is computed and handed to both the quadratic-form
  and log-determinant helpers so the factorization is not repeated.

  Args:
    mean: Tensor of locations; its trailing axis is taken as the event axis.
    covariance: Tensor of covariance matrices, passed to `tf.linalg.cholesky`.
      NOTE(review): presumably shaped `[..., d, d]` and positive definite --
      confirm against callers.
    y: Tensor of observed values; `y - mean` must be a valid broadcast.

  Returns:
    Tensor of log-probabilities.
  """
  cholesky_matrix = tf.linalg.cholesky(covariance)
  quadratic_term = linalg.hpsd_quadratic_form_solvevec(
      covariance, y - mean, cholesky_matrix=cholesky_matrix)
  log_det_term = linalg.hpsd_logdet(
      covariance, cholesky_matrix=cholesky_matrix)
  # `ps.shape` yields an integer (a NumPy value when the shape is static, an
  # int32 Tensor otherwise). Cast it to the floating dtype before scaling:
  # multiplying an int32 Tensor by the Python float 0.5 raises, and a NumPy
  # integer would silently promote the constant to float64.
  event_dims = tf.cast(ps.shape(mean)[-1], mean.dtype)
  log_two_pi = dtype_util.as_numpy_dtype(mean.dtype)(np.log(2 * np.pi))
  log_prob = -0.5 * quadratic_term
  log_prob = log_prob - 0.5 * log_det_term
  return log_prob - 0.5 * event_dims * log_two_pi
654
+
655
+
647
656
def _extract_batch_shape (x , sample_ndims , event_ndims ):
648
657
"""Slice out the batch component of `x`'s shape."""
649
658
if x is None :
0 commit comments