File tree Expand file tree Collapse file tree 2 files changed +18
-3
lines changed
tensorflow_probability/python/experimental/sts_gibbs Expand file tree Collapse file tree 2 files changed +18
-3
lines changed Original file line number Diff line number Diff line change @@ -140,13 +140,16 @@ class SpikeAndSlabSparseLinearRegression(sts_components.LinearRegression):
140
140
141
141
def __init__ (self ,
142
142
design_matrix ,
143
- weights_prior = None ,
143
+ weights_prior ,
144
144
sparse_weights_nonzero_prob = 0.5 ,
145
145
name = None ):
146
146
# Extract precision matrix from a multivariate normal prior.
147
147
weights_prior_precision = None
148
148
if hasattr (weights_prior , 'precision' ):
149
- weights_prior_precision = weights_prior .precision ()
149
+ if isinstance (weights_prior .precision , tf .linalg .LinearOperator ):
150
+ weights_prior_precision = weights_prior .precision .to_dense ()
151
+ else :
152
+ weights_prior_precision = weights_prior .precision ()
150
153
elif weights_prior is not None :
151
154
inverse_scale = weights_prior .scale .inverse ()
152
155
weights_prior_precision = inverse_scale .matmul (
@@ -840,7 +843,6 @@ def _build_sampler_loop_body(model,
840
843
sampler = dynamic_spike_and_slab .DynamicSpikeSlabSampler
841
844
else :
842
845
sampler = spike_and_slab .SpikeSlabSampler
843
-
844
846
spike_and_slab_sampler = sampler (
845
847
design_matrix ,
846
848
weights_prior_precision = regression_component ._weights_prior_precision , # pylint: disable=protected-access
Original file line number Diff line number Diff line change 29
29
30
30
31
31
tfd = tfp .distributions
32
+ tfde = tfp .experimental .distributions
32
33
tfl = tf .linalg
33
34
34
35
JAX_MODE = False
@@ -398,6 +399,18 @@ def test_invalid_model_spec_raises_error(self):
398
399
level_variance_prior = tfd .InverseGamma (0.01 , 0.01 ),
399
400
observation_noise_variance_prior = tfd .LogNormal (0. , 3. ))
400
401
402
+ def test_model_with_linop_precision_works (self ):
403
+ observed_time_series = tf .ones ([2 ])
404
+ design_matrix = tf .eye (2 )
405
+ sampler = gibbs_sampler .build_model_for_gibbs_fitting (
406
+ observed_time_series ,
407
+ design_matrix = design_matrix ,
408
+ weights_prior = tfde .MultivariateNormalPrecisionFactorLinearOperator (
409
+ precision_factor = tf .linalg .LinearOperatorDiag (tf .ones (2 ))),
410
+ level_variance_prior = tfd .InverseGamma (0.01 , 0.01 ),
411
+ observation_noise_variance_prior = tfd .InverseGamma (0.01 , 0.01 ))
412
+ self .assertIsNotNone (sampler )
413
+
401
414
def test_invalid_options_with_none_design_matrix_raises_error (self ):
402
415
observed_time_series = tf .ones ([2 ])
403
416
with self .assertRaisesRegex (
You can’t perform that action at this time.
0 commit comments