1616
1717import collections
1818
19- import numpy as np
2019import tensorflow .compat .v2 as tf
2120
2221from tensorflow_probability .python import math as tfp_math
@@ -53,6 +52,26 @@ def _sample_n(self, n, seed=None):
5352 return xs
5453
5554
55+ class MVNPrecisionFactorHardZeros (
56+ MultivariateNormalPrecisionFactorLinearOperator ):
57+ """Multivariate normal that forces some sample dimensions to zero.
58+
59+ This is equivalent to setting `loc[d] = 0.` and `precision_factor[d, d]=`inf`
60+ in the zeroed dimensions, but is numerically better behaved.
61+ """
62+
63+ def __init__ (self , loc , precision_factor , nonzeros , ** kwargs ):
64+ self ._nonzeros = nonzeros
65+ super ().__init__ (loc = loc , precision_factor = precision_factor , ** kwargs )
66+
67+ def _call_sample_n (self , * args , ** kwargs ):
68+ xs = super ()._call_sample_n (* args , ** kwargs )
69+ return tf .where (self ._nonzeros , xs , 0. )
70+
71+ def _log_prob (self , * args , ** kwargs ):
72+ raise NotImplementedError ('Log prob is not currently implemented.' )
73+
74+
5675class SpikeSlabSamplerState (collections .namedtuple (
5776 'SpikeSlabSamplerState' ,
5877 ['x_transpose_y' ,
@@ -513,14 +532,6 @@ def _compute_log_prob(
513532
514533 def _get_conditional_posterior (self , sampler_state ):
515534 """Builds the joint posterior for a sparsity pattern (eqn (7) from [1])."""
516- # Impose a hard, infinite-precision constraint on zeroed-out features, in
517- # place of the identity-matrix representation that we used for numerical
518- # convenience during sampling.
519- hard_precision_factor = _select_nonzero_block (
520- sampler_state .conditional_posterior_precision_chol ,
521- nonzeros = sampler_state .nonzeros ,
522- identity_multiplier = np .inf )
523-
524535 @joint_distribution_auto_batched .JointDistributionCoroutineAutoBatched
525536 def posterior_jd ():
526537 observation_noise_variance = yield InverseGammaWithSampleUpperBound (
@@ -529,20 +540,21 @@ def posterior_jd():
529540 scale = sampler_state .observation_noise_variance_posterior_scale ,
530541 upper_bound = self .observation_noise_variance_upper_bound ,
531542 name = 'observation_noise_variance' )
532- yield MultivariateNormalPrecisionFactorLinearOperator (
543+ yield MVNPrecisionFactorHardZeros (
533544 loc = sampler_state .conditional_weights_mean ,
534545 # Note that the posterior precision varies inversely with the
535546 # noise variance: in worlds with high noise we're also
536547 # more uncertain about the values of the weights.
537548 precision_factor = tf .linalg .LinearOperatorLowerTriangular (
538- hard_precision_factor /
549+ sampler_state . conditional_posterior_precision_chol /
539550 observation_noise_variance [..., tf .newaxis , tf .newaxis ]),
551+ nonzeros = sampler_state .nonzeros ,
540552 name = 'weights' )
541553
542554 return posterior_jd
543555
544556
545- def _select_nonzero_block (matrix , nonzeros , identity_multiplier = 1. ):
557+ def _select_nonzero_block (matrix , nonzeros ):
546558 """Replaces the `i`th row & col with the identity if not `nonzeros[i]`.
547559
548560 This function effectively selects the 'slab' rows (corresponding to
@@ -566,7 +578,6 @@ def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
566578 matrix: (batch of) float Tensor matrix(s) of shape
567579 `[num_features, num_features]`.
568580 nonzeros: (batch of) boolean Tensor vectors of shape `[num_features]`.
569- identity_multiplier: optional scalar multiplier for the identity matrix.
570581 Returns:
571582 block_matrix: (batch of) float Tensor matrix(s) of the same shape as
572583 `matrix`, in which `block_matrix[i, j] = matrix[i, j] if
@@ -578,13 +589,10 @@ def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
578589 masked = tf .where (nonzeros [..., tf .newaxis ],
579590 tf .where (nonzeros [..., tf .newaxis , :], matrix , 0. ),
580591 0. )
581- # Restore a value of `identity_multiplier` on the diagonal of the not-selected
582- # rows. This avoids numerical issues by ensuring that the matrix still has
583- # full rank.
592+ # Restore a value of 1 on the diagonal of the not-selected rows. This avoids
593+ # numerical issues by ensuring that the matrix still has full rank.
584594 return tf .linalg .set_diag (masked ,
585- tf .where (nonzeros ,
586- tf .linalg .diag_part (masked ),
587- identity_multiplier ))
595+ tf .where (nonzeros , tf .linalg .diag_part (masked ), 1. ))
588596
589597
590598def _update_nonzero_block_chol (chol , idx , psd_matrix , new_nonzeros ):
0 commit comments