Skip to content

Commit 41c3c00

Browse files
srvasude authored and jburnim committed
Improve backprop performance through experimental kalman filter, by changing out MVN log_prob calculation.
PiperOrigin-RevId: 579184615
1 parent dbe6138 commit 41c3c00

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tensorflow.compat.v2 as tf
2222

2323
from tensorflow_probability.python.distributions import mvn_tril
24+
from tensorflow_probability.python.internal import dtype_util
2425
from tensorflow_probability.python.internal import prefer_static as ps
2526
from tensorflow_probability.python.internal import samplers
2627
from tensorflow_probability.python.math import linalg
@@ -625,11 +626,8 @@ def kalman_filter(transition_matrix,
625626
axis=0),
626627
added_cov=time_dep.observation_cov)
627628

628-
# TODO(srvasude): The JVP for this can be implemented more efficiently.
629-
log_likelihoods = mvn_tril.MultivariateNormalTriL(
630-
loc=observation_means,
631-
scale_tril=tf.linalg.cholesky(observation_covs)).log_prob(
632-
observation.y)
629+
log_likelihoods = _mvn_log_prob(
630+
observation_means, observation_covs, observation.y)
633631
if observation.mask is not None:
634632
log_likelihoods = tf.where(observation.mask,
635633
tf.zeros([], dtype=log_likelihoods.dtype),
@@ -644,6 +642,17 @@ def kalman_filter(transition_matrix,
644642
observation_covs)
645643

646644

645+
def _mvn_log_prob(mean, covariance, y):
  """Computes a multivariate normal log-density without building a distribution.

  Evaluates

    -0.5 * [(y - mean)^T covariance^{-1} (y - mean)
            + log|covariance| + d * log(2 * pi)]

  where `d` is the size of the last dimension of `mean`. A single Cholesky
  factorization of `covariance` is computed and handed to both the
  quadratic-form and the log-determinant helpers.

  Args:
    mean: Floating-point `Tensor` of locations. Its last dimension is used as
      the event size and its dtype as the dtype of the normalizing constant.
    covariance: `Tensor` of covariance matrices. It is passed to
      `tf.linalg.cholesky`, so it is expected to be symmetric positive
      definite.
    y: `Tensor` of points at which the density is evaluated; `y - mean` must
      be a valid (broadcasting) subtraction.

  Returns:
    log_prob: `Tensor` of log-densities, one per `(mean, covariance, y)`
      combination after broadcasting.
  """
  dtype = mean.dtype
  cholesky_matrix = tf.linalg.cholesky(covariance)
  residual = y - mean

  quadratic_form = linalg.hpsd_quadratic_form_solvevec(
      covariance, residual, cholesky_matrix=cholesky_matrix)
  log_det = linalg.hpsd_logdet(covariance, cholesky_matrix=cholesky_matrix)

  # `ps.shape` falls back to a runtime integer `Tensor` when the static shape
  # is not fully known. Multiplying that by a Python float fails in
  # TensorFlow, so cast the event size to the floating dtype explicitly.
  event_size = tf.cast(ps.shape(mean)[-1], dtype)
  log_2pi = dtype_util.as_numpy_dtype(dtype)(np.log(2 * np.pi))

  return -0.5 * (quadratic_form + log_det + event_size * log_2pi)
654+
655+
647656
def _extract_batch_shape(x, sample_ndims, event_ndims):
648657
"""Slice out the batch component of `x`'s shape."""
649658
if x is None:

0 commit comments

Comments
 (0)