Skip to content

Commit ac62542

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
Support MultivariateNormalPrecisionFactorLinearOperator as a weight prior.
PiperOrigin-RevId: 451466503
1 parent 68f626f commit ac62542

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,16 @@ class SpikeAndSlabSparseLinearRegression(sts_components.LinearRegression):
140140

141141
def __init__(self,
142142
design_matrix,
143-
weights_prior=None,
143+
weights_prior,
144144
sparse_weights_nonzero_prob=0.5,
145145
name=None):
146146
# Extract precision matrix from a multivariate normal prior.
147147
weights_prior_precision = None
148148
if hasattr(weights_prior, 'precision'):
149-
weights_prior_precision = weights_prior.precision()
149+
if isinstance(weights_prior.precision, tf.linalg.LinearOperator):
150+
weights_prior_precision = weights_prior.precision.to_dense()
151+
else:
152+
weights_prior_precision = weights_prior.precision()
150153
elif weights_prior is not None:
151154
inverse_scale = weights_prior.scale.inverse()
152155
weights_prior_precision = inverse_scale.matmul(
@@ -840,7 +843,6 @@ def _build_sampler_loop_body(model,
840843
sampler = dynamic_spike_and_slab.DynamicSpikeSlabSampler
841844
else:
842845
sampler = spike_and_slab.SpikeSlabSampler
843-
844846
spike_and_slab_sampler = sampler(
845847
design_matrix,
846848
weights_prior_precision=regression_component._weights_prior_precision, # pylint: disable=protected-access

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030

3131
tfd = tfp.distributions
32+
tfde = tfp.experimental.distributions
3233
tfl = tf.linalg
3334

3435
JAX_MODE = False
@@ -398,6 +399,18 @@ def test_invalid_model_spec_raises_error(self):
398399
level_variance_prior=tfd.InverseGamma(0.01, 0.01),
399400
observation_noise_variance_prior=tfd.LogNormal(0., 3.))
400401

402+
def test_model_with_linop_precision_works(self):
403+
observed_time_series = tf.ones([2])
404+
design_matrix = tf.eye(2)
405+
sampler = gibbs_sampler.build_model_for_gibbs_fitting(
406+
observed_time_series,
407+
design_matrix=design_matrix,
408+
weights_prior=tfde.MultivariateNormalPrecisionFactorLinearOperator(
409+
precision_factor=tf.linalg.LinearOperatorDiag(tf.ones(2))),
410+
level_variance_prior=tfd.InverseGamma(0.01, 0.01),
411+
observation_noise_variance_prior=tfd.InverseGamma(0.01, 0.01))
412+
self.assertIsNotNone(sampler)
413+
401414
def test_invalid_options_with_none_design_matrix_raises_error(self):
402415
observed_time_series = tf.ones([2])
403416
with self.assertRaisesRegex(

0 commit comments

Comments
 (0)