Skip to content

Commit 21c10c1

Browse files
emilyfertig authored and tensorflower-gardener committed
Replace usages of tfd.MultivariateNormalFullCovariance in TFP.
PiperOrigin-RevId: 464185451
1 parent 677fd25 commit 21c10c1

File tree

12 files changed

+87
-81
lines changed

12 files changed

+87
-81
lines changed

discussion/neutra/neutra_kernel_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def testSingleTensor(self, bijector, dtype):
4141
base_mean = tf.convert_to_tensor(value=[1., 0], dtype=dtype)
4242
base_cov = tf.convert_to_tensor(value=[[1, 0.5], [0.5, 1]], dtype=dtype)
4343

44-
base_dist = tfd.MultivariateNormalFullCovariance(
45-
loc=base_mean, covariance_matrix=base_cov)
44+
base_dist = tfd.MultivariateNormalTriL(
45+
loc=base_mean, scale_tril=tf.linalg.cholesky(base_cov))
4646
target_dist = bijector(base_dist)
4747

4848
def debug_fn(*args):
@@ -90,8 +90,8 @@ def testNested(self, bijector):
9090
base_mean = tf.constant([1., 0])
9191
base_cov = tf.constant([[1, 0.5], [0.5, 1]])
9292

93-
dist_2d = tfd.MultivariateNormalFullCovariance(
94-
loc=base_mean, covariance_matrix=base_cov)
93+
dist_2d = tfd.MultivariateNormalTriL(
94+
loc=base_mean, scale_tril=tf.linalg.cholesky(base_cov))
9595
dist_4d = tfd.MultivariateNormalDiag(scale_diag=tf.ones(4))
9696

9797
target_dist = tfd.JointDistributionSequential([

tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Distributions_Tutorial.ipynb

Lines changed: 38 additions & 36 deletions
Large diffs are not rendered by default.

tensorflow_probability/python/distributions/linear_gaussian_ssm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,10 +1035,10 @@ def _forward_filter_sequential(self, x, mask=None, final_step_only=False):
10351035
log_marginal_likelihood=0, timestep=0))
10361036

10371037
# We could directly construct the batch Distributions
1038-
# filtered_marginals = tfd.MultivariateNormalFullCovariance(
1039-
# filtered_means, filtered_covs)
1040-
# predicted_marginals = tfd.MultivariateNormalFullCovariance(
1041-
# predicted_means, predicted_covs)
1038+
# filtered_marginals = tfd.MultivariateNormalTriL(
1039+
# filtered_means, tf.linalg.cholesky(filtered_covs))
1040+
# predicted_marginals = tfd.MultivariateNormalTriL(
1041+
# predicted_means, tf.linalg.cholesky(predicted_covs))
10421042
# but we choose not to: returning the raw means and covariances
10431043
# saves computation in Eager mode (avoiding an immediate
10441044
# Cholesky factorization that the user may not want) and aids

tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ def test_all_distributions_either_work_or_raise_error(self, dist_name, data):
122122

123123
@test_util.numpy_disable_gradient_test
124124
def test_multivariate_normal(self):
125-
d = tfd.MultivariateNormalFullCovariance(
126-
loc=[4., 8.], covariance_matrix=[[11., 0.099], [0.099, 0.1]])
125+
d = tfd.MultivariateNormalTriL(
126+
loc=[4., 8.],
127+
scale_tril=tf.linalg.cholesky([[11., 0.099], [0.099, 0.1]]))
127128
b = tfp.experimental.bijectors.make_distribution_bijector(d)
128129
self.assertDistributionIsApproximatelyStandardNormal(tfb.Invert(b)(d))
129130

tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,8 @@ def target(*args):
132132

133133

134134
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
135-
initialization = tfd.MultivariateNormalFullCovariance(
136-
loc=tf.zeros(event_size),
137-
covariance_matrix=tf.eye(event_size)).sample(
138-
batch_size, seed=4)
135+
initialization = tfd.MultivariateNormalDiag(
136+
loc=tf.zeros(event_size)).sample(batch_size, seed=4)
139137
samples, leapfrogs = run_nuts_chain(
140138
event_size, batch_size, num_steps=1,
141139
initial_state=initialization, **kwargs)

tensorflow_probability/python/experimental/mcmc/pnuts_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,8 @@ def run_chain():
120120

121121
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
122122
strm = test_util.test_seed_stream()
123-
initialization = tfd.MultivariateNormalFullCovariance(
124-
loc=tf.zeros(event_size),
125-
covariance_matrix=tf.eye(event_size)).sample(
126-
batch_size, seed=strm())
123+
initialization = tfd.MultivariateNormalDiag(
124+
loc=tf.zeros(event_size)).sample(batch_size, seed=strm())
127125
samples, _ = run_nuts_chain(
128126
event_size, batch_size, num_steps=1,
129127
initial_state=initialization, **kwargs)

tensorflow_probability/python/experimental/mcmc/with_reductions_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ def test_multivariate_normal_covariance_with_sample_chain(self):
240240
cov = [[0.36, 0.12, 0.06],
241241
[0.12, 0.29, -0.13],
242242
[0.06, -0.13, 0.26]]
243-
target = tfp.distributions.MultivariateNormalFullCovariance(
244-
loc=mu, covariance_matrix=cov
243+
target = tfp.distributions.MultivariateNormalTriL(
244+
loc=mu, scale_tril=tf.linalg.cholesky(cov)
245245
)
246246
fake_kernel = tfp.mcmc.HamiltonianMonteCarlo(
247247
target_log_prob_fn=target.log_prob,

tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -233,21 +233,23 @@ def batch_generator():
233233
(type(my_filter_results)(1, 2, 3, 2, 3, 2, 3), 2, 1))
234234

235235
# pylint: disable=g-long-lambda,cell-var-from-loop
236-
mvn = tfd.MultivariateNormalFullCovariance
237236
dist = tfd.LinearGaussianStateSpaceModel(
238237
num_timesteps=nsteps,
239238
transition_matrix=lambda t: tf.linalg.LinearOperatorFullMatrix(
240239
tf.gather(transition_matrix, t, axis=0)),
241-
transition_noise=lambda t: mvn(
240+
transition_noise=lambda t: tfd.MultivariateNormalTriL(
242241
loc=tf.gather(transition_mean, t, axis=0),
243-
covariance_matrix=tf.gather(transition_cov, t, axis=0)),
242+
scale_tril=tf.linalg.cholesky(
243+
tf.gather(transition_cov, t, axis=0))),
244244
observation_matrix=lambda t: tf.linalg.LinearOperatorFullMatrix(
245245
tf.gather(observation_matrix, t, axis=0)),
246-
observation_noise=lambda t: mvn(
246+
observation_noise=lambda t: tfd.MultivariateNormalTriL(
247247
loc=tf.gather(observation_mean, t, axis=0),
248-
covariance_matrix=tf.gather(observation_cov, t, axis=0)),
249-
initial_state_prior=mvn(loc=initial_mean,
250-
covariance_matrix=initial_cov),
248+
scale_tril=tf.linalg.cholesky(
249+
tf.gather(observation_cov, t, axis=0))),
250+
initial_state_prior=tfd.MultivariateNormalTriL(
251+
loc=initial_mean,
252+
scale_tril=tf.linalg.cholesky(initial_cov)),
251253
experimental_parallelize=False) # Compare against sequential filter.
252254
# pylint: enable=g-long-lambda,cell-var-from-loop
253255

tensorflow_probability/python/internal/backend/numpy/linalg_impl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def _band_part(input, num_lower, num_upper, name=None): # pylint: disable=redef
8484

8585
def _cholesky_solve(chol, rhs, name=None): # pylint: disable=unused-argument
8686
"""Scipy cho_solve does not broadcast, so we must do so explicitly."""
87+
chol = ops.convert_to_tensor(chol)
88+
rhs = ops.convert_to_tensor(rhs)
8789
if JAX_MODE: # But JAX uses XLA, which can do a batched solve.
8890
chol = chol + np.zeros(rhs.shape[:-2] + (1, 1), dtype=chol.dtype)
8991
rhs = rhs + np.zeros(chol.shape[:-2] + (1, 1), dtype=rhs.dtype)
@@ -368,7 +370,7 @@ def _triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None): # pyl
368370

369371
cholesky = utils.copy_docstring(
370372
'tf.linalg.cholesky',
371-
lambda input, name=None: np.linalg.cholesky(input))
373+
lambda input, name=None: np.linalg.cholesky(ops.convert_to_tensor(input)))
372374

373375
cholesky_solve = utils.copy_docstring(
374376
'tf.linalg.cholesky_solve',

tensorflow_probability/python/mcmc/nuts_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,8 @@ def run_chain():
116116

117117
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
118118
strm = test_util.test_seed_stream()
119-
initialization = tfd.MultivariateNormalFullCovariance(
120-
loc=tf.zeros(event_size),
121-
covariance_matrix=tf.eye(event_size)).sample(
122-
batch_size, seed=strm())
119+
initialization = tfd.MultivariateNormalDiag(
120+
loc=tf.zeros(event_size)).sample(batch_size, seed=strm())
123121
samples, _ = run_nuts_chain(
124122
event_size, batch_size, num_steps=1,
125123
initial_state=initialization, **kwargs)

0 commit comments

Comments (0)