Skip to content

Commit 9d88c41

Browse files
axch authored and tensorflower-gardener committed
Implement optional Gauss-Hermite quadrature for computing statistics of the LogitNormal distribution.
PiperOrigin-RevId: 383682512
1 parent d4609c2 commit 9d88c41

File tree

2 files changed

+149
-4
lines changed

2 files changed

+149
-4
lines changed

tensorflow_probability/python/distributions/logitnormal.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import numpy as onp
22+
2123
import tensorflow.compat.v2 as tf
2224

2325
from tensorflow_probability.python import math as tfp_math
@@ -47,6 +49,8 @@ def __init__(self,
4749
loc,
4850
scale,
4951
num_probit_terms_approx=2,
52+
gauss_hermite_scale_limit=None,
53+
gauss_hermite_degree=20,
5054
validate_args=False,
5155
allow_nan_stats=True,
5256
name='LogitNormal'):
@@ -71,6 +75,18 @@ def __init__(self,
7175
(inclusive). Using `num_probit_terms_approx=2` should result in
7276
`mean_approx` error not exceeding `10**-4`.
7377
Default value: `2`.
78+
gauss_hermite_scale_limit: Floating-point `Tensor` or `None`.
79+
The (batch-wise) maximum scale at which to compute statistics
80+
with Gauss-Hermite quadrature instead of the Monahan-Stefanski
81+
approximation [1]. Default: `None`, which recovers the legacy
82+
behavior of using Monahan-Stefanski everywhere and does not
83+
add TF ops for Gauss-Hermite. The best value depends on the
84+
working precision and the number of terms in the Gauss-Hermite
85+
or Monahan-Stefanski approximations being switched between,
86+
as well as the expected range of `loc` parameters; but `1` is
87+
not unreasonable.
88+
gauss_hermite_degree: Python integer giving the number of
89+
sample points to use for Gauss-Hermite quadrature.
7490
validate_args: Python `bool`, default `False`. Whether to validate input
7591
with asserts. If `validate_args` is `False`, and the inputs are
7692
invalid, correct behavior is not guaranteed.
@@ -90,6 +106,7 @@ def __init__(self,
90106
Communications in Statistics-Simulation and Computation 9.4 (1980):
91107
389-419.
92108
https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
109+
93110
"""
94111
parameters = dict(locals())
95112
num_probit_terms_approx = int(num_probit_terms_approx)
@@ -98,6 +115,8 @@ def __init__(self,
98115
'Argument `num_probit_terms_approx` must be an integer between '
99116
'`1` and `8` (inclusive).')
100117
self._num_probit_terms_approx = num_probit_terms_approx
118+
self._gauss_hermite_scale_limit = gauss_hermite_scale_limit
119+
self._gauss_hermite_degree = gauss_hermite_degree
101120
with tf.name_scope(name) as name:
102121
super(LogitNormal, self).__init__(
103122
distribution=normal_lib.Normal(loc=loc, scale=scale),
@@ -131,6 +150,16 @@ def num_probit_terms_approx(self):
131150
"""Number of `Normal(0, 1).cdf` terms used in `mean_*_approx` functions."""
132151
return self._num_probit_terms_approx
133152

153+
@property
154+
def gauss_hermite_scale_limit(self):
155+
"""Scale limit for using Gauss-Hermite quadrature in `*_approx` functions.

Returns the `gauss_hermite_scale_limit` value given to the constructor,
unchanged. `mean_approx` and `variance_approx` select the Gauss-Hermite
answer where `scale` is strictly below this limit and the
Monahan-Stefanski answer elsewhere; when it is `None` they return the
Monahan-Stefanski answer everywhere.
"""
156+
return self._gauss_hermite_scale_limit
157+
158+
@property
159+
def gauss_hermite_degree(self):
160+
"""Number of points for Gauss-Hermite quadrature in `*_approx` functions.

Returns the `gauss_hermite_degree` value given to the constructor,
unchanged. `mean_approx` and `variance_approx` pass it as the `deg`
argument of the Gauss-Hermite helpers.
"""
161+
return self._gauss_hermite_degree
162+
134163
experimental_is_sharded = False
135164

136165
def mean_log_prob_approx(self, y=None, name='mean_log_prob_approx'):
@@ -199,10 +228,19 @@ def mean_approx(self, name='mean_approx'):
199228
https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
200229
"""
201230
with self._name_and_control_scope(name):
202-
return approx_expected_sigmoid(
203-
self.loc, self.scale,
231+
loc = tf.convert_to_tensor(self.loc)
232+
scale = tf.convert_to_tensor(self.scale)
233+
monahan_stefanski_answer = approx_expected_sigmoid(
234+
loc, scale,
204235
MONAHAN_MIX_PROB[self.num_probit_terms_approx],
205236
MONAHAN_INVERSE_SCALE[self.num_probit_terms_approx])
237+
if self.gauss_hermite_scale_limit is None:
238+
return monahan_stefanski_answer
239+
else:
240+
gauss_hermite_answer = logit_normal_mean_gh(
241+
loc, scale, self.gauss_hermite_degree)
242+
return tf.where(scale < self.gauss_hermite_scale_limit,
243+
gauss_hermite_answer, monahan_stefanski_answer)
206244

207245
def variance_approx(self, name='variance_approx'):
208246
"""Approximate the variance of a LogitNormal.
@@ -233,10 +271,19 @@ def variance_approx(self, name='variance_approx'):
233271
https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
234272
"""
235273
with self._name_and_control_scope(name):
236-
return approx_variance_sigmoid(
237-
self.loc, self.scale,
274+
loc = tf.convert_to_tensor(self.loc)
275+
scale = tf.convert_to_tensor(self.scale)
276+
monahan_stefanski_answer = approx_variance_sigmoid(
277+
loc, scale,
238278
MONAHAN_MIX_PROB[self.num_probit_terms_approx],
239279
MONAHAN_INVERSE_SCALE[self.num_probit_terms_approx])
280+
if self.gauss_hermite_scale_limit is None:
281+
return monahan_stefanski_answer
282+
else:
283+
gauss_hermite_answer = logit_normal_variance_gh(
284+
loc, scale, self.gauss_hermite_degree)
285+
return tf.where(scale < self.gauss_hermite_scale_limit,
286+
gauss_hermite_answer, monahan_stefanski_answer)
240287

241288
def stddev_approx(self, name='stddev_approx'):
242289
"""Approximate the standard deviation of a LogitNormal.
@@ -479,3 +526,37 @@ def approx_variance_sigmoid(
479526
alpha[tf.newaxis, :] * alpha[:, tf.newaxis] * (b + bt),
480527
axis=[-2, -1])
481528
return mom2 - approx_expected_sigmoid(m, s, alpha, c)**2.
529+
530+
531+
# The above approximations fail for small scales. We compute
532+
# statistics for small scales with Gauss-Hermite quadrature.
533+
534+
535+
def logit_normal_mean_gh(loc, scale, deg):
  """Approximates `E_{N(m,s)}[sigmoid(X)]` by Gauss-Hermite quadrature.

  Args:
    loc: Floating-point `Tensor`, or a value convertible to one, giving the
      mean `m` of the underlying Normal.
    scale: Floating-point `Tensor`, or a value convertible to one, giving the
      standard deviation `s` of the underlying Normal. Must broadcast against
      `loc`.
    deg: Python integer giving the number of quadrature sample points.

  Returns:
    A `Tensor` with the dtype of `loc` and the broadcast shape of `loc` and
    `scale`, holding the quadrature estimate of the mean of
    `sigmoid(Normal(loc, scale))`.
  """
  # Accept Python scalars and NumPy arrays too: the code below reads
  # `loc.dtype` and indexes with `tf.newaxis`, which plain floats do not
  # support.  `dtype_hint` lets a non-Tensor `scale` adopt the dtype of `loc`.
  loc = tf.convert_to_tensor(loc)
  scale = tf.convert_to_tensor(scale, dtype_hint=loc.dtype)

  # We want to integrate
  #   A = \int_-inf^inf sigmoid(x) * Normal(loc, scale).pdf(x) dx
  # To bring it into the right form for Gauss-Hermite quadrature,
  # we make the substitution y = (x - loc) / scale, to get
  #   A = (1/sqrt(2*pi)) * \int_-inf^inf [
  #     sigmoid(y * scale + loc) * exp(-1/2 y**2) dy]
  grid, weights = onp.polynomial.hermite_e.hermegauss(deg)
  # Fold the 1/sqrt(2*pi) normalizer into the weights while they are still
  # NumPy values, so no extra TF op is added for the division.
  weights = weights / onp.sqrt(2 * onp.pi)
  grid = tf.cast(grid, dtype=loc.dtype)
  weights = tf.cast(weights, dtype=loc.dtype)

  # The trailing axis indexes quadrature points; it is summed away below.
  values = tf.sigmoid(grid * scale[..., tf.newaxis] + loc[..., tf.newaxis])
  return tf.reduce_sum(values * weights, axis=-1)
549+
550+
551+
def logit_normal_variance_gh(loc, scale, deg):
  """Approximates `Var_{N(m,s)}[sigmoid(X)]` by Gauss-Hermite quadrature.

  Args:
    loc: Floating-point `Tensor`, or a value convertible to one, giving the
      mean `m` of the underlying Normal.
    scale: Floating-point `Tensor`, or a value convertible to one, giving the
      standard deviation `s` of the underlying Normal. Must broadcast against
      `loc`.
    deg: Python integer giving the number of quadrature sample points.

  Returns:
    A `Tensor` with the dtype of `loc` and the broadcast shape of `loc` and
    `scale`, holding the quadrature estimate of the variance of
    `sigmoid(Normal(loc, scale))`.
  """
  # Accept Python scalars and NumPy arrays too: the code below reads
  # `loc.dtype` and indexes with `tf.newaxis`, which plain floats do not
  # support.  `dtype_hint` lets a non-Tensor `scale` adopt the dtype of `loc`.
  loc = tf.convert_to_tensor(loc)
  scale = tf.convert_to_tensor(scale, dtype_hint=loc.dtype)

  # Since we have to compute sigmoids for variance anyway, we inline
  # computing the mean by Gauss-Hermite quadrature at the same grid of points.
  grid, weights = onp.polynomial.hermite_e.hermegauss(deg)
  # Fold the 1/sqrt(2*pi) normalizer into the weights while they are still
  # NumPy values, so no extra TF op is added for the division.
  weights = weights / onp.sqrt(2 * onp.pi)
  grid = tf.cast(grid, dtype=loc.dtype)
  weights = tf.cast(weights, dtype=loc.dtype)

  # The trailing axis indexes quadrature points; it is summed away below.
  sigmoids = tf.sigmoid(grid * scale[..., tf.newaxis] + loc[..., tf.newaxis])
  mean = tf.reduce_sum(sigmoids * weights, axis=-1)
  # Integrate the squared deviation from the mean with the same rule.
  residuals = (sigmoids - mean[..., tf.newaxis])**2
  return tf.reduce_sum(residuals * weights, axis=-1)

tensorflow_probability/python/distributions/logitnormal_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,42 @@
2828

2929

3030
tfd = tfp.distributions
31+
ln_lib = tfd.logitnormal
32+
33+
34+
def logit_normal_trapezoid_rule(loc, scale):
  """Builds a trapezoid-rule integrator for LogitNormal(loc, scale) statistics.

  A LogitNormal sample is `sigmoid(z)` with `z ~ Normal(loc, scale)`, so its
  statistics are integrals of `f(z) * Normal.pdf(z)` over `z`.  The integrand
  `f` is bounded and the Normal density is negligible more than 10 scales
  away from `loc`, so the integration range is truncated to
  `[loc - 10 * scale, loc + 10 * scale]`.

  Args:
    loc: Location of the underlying Normal.
    scale: Scale of the underlying Normal.

  Returns:
    grid: The evaluation points, from `tf.linspace` with 10000 points.
    integrate: Callable mapping integrand values on `grid` to the
      trapezoid-rule estimate of the integral (reducing over axis 0).
  """
  num_points = 10000
  half_width = 10.0
  lower = loc - half_width*scale
  upper = loc + half_width*scale
  grid = tf.linspace(lower, upper, num_points)

  def integrate(integrand):
    # Interior points get full weight; the two end points get half weight.
    endpoint_correction = 0.5 * (integrand[0] + integrand[-1])
    weighted_sum = tf.reduce_sum(integrand, axis=0) - endpoint_correction
    num_intervals = tf.cast((num_points-1), grid.dtype)
    # The step size is the full range (2 * half_width * scale) divided by
    # the number of intervals.
    return weighted_sum * 2 * half_width * scale / num_intervals

  return grid, integrate
50+
51+
52+
def logit_normal_mean_trapezoid(loc, scale):
  """Estimates the mean of LogitNormal(loc, scale) with the trapezoid rule.

  Integrates `sigmoid(z) * Normal(loc, scale).prob(z)` over the grid supplied
  by `logit_normal_trapezoid_rule`.
  """
  zs, integrate = logit_normal_trapezoid_rule(loc, scale)
  density = tfd.Normal(loc, scale).prob(zs)
  return integrate(tf.sigmoid(zs) * density)
57+
58+
59+
def logit_normal_variance_trapezoid(loc, scale):
  """Estimates the variance of LogitNormal(loc, scale) with the trapezoid rule.

  First integrates `sigmoid(z)` against the Normal density to get the mean,
  then integrates the squared deviation from that mean on the same grid.
  """
  zs, integrate = logit_normal_trapezoid_rule(loc, scale)
  density = tfd.Normal(loc, scale).prob(zs)
  squashed = tf.sigmoid(zs)
  first_moment = integrate(squashed * density)
  squared_deviation = (squashed - first_moment)**2
  return integrate(squared_deviation * density)
3167

3268

3369
@test_util.test_all_tf_execution_regimes
@@ -69,6 +105,34 @@ def testLogitNormalVarianceApprox(self):
69105
self.assertAllClose(
70106
variance_sample_, variance_approx_, atol=1e-4, rtol=0.03)
71107

108+
def testLogitNormalMeanGH(self):
  """Gauss-Hermite mean matches brute-force trapezoid quadrature."""
  loc_axis = tf.linspace(-10.0, 10.0, 10)
  scale_axis = tf.exp(tf.linspace(-3.0, 0.0, 10))
  locs, scales = tf.meshgrid(loc_axis, scale_axis)
  actual = ln_lib.logit_normal_mean_gh(locs, scales, deg=50)
  expected = logit_normal_mean_trapezoid(locs, scales)
  self.assertAllClose(expected, actual, rtol=1e-4)
114+
115+
def testLogitNormalVarianceGH(self):
  """Gauss-Hermite variance matches brute-force trapezoid quadrature."""
  loc_axis = tf.linspace(-10.0, 10.0, 10)
  scale_axis = tf.exp(tf.linspace(-3.0, 0.0, 10))
  locs, scales = tf.meshgrid(loc_axis, scale_axis)
  actual = ln_lib.logit_normal_variance_gh(locs, scales, deg=50)
  expected = logit_normal_variance_trapezoid(locs, scales)
  self.assertAllClose(expected, actual, rtol=1e-4)
121+
122+
def testLogitNormalMeanAndVariance(self):
  """`mean_approx` and `variance_approx` match trapezoid quadrature.

  The scale grid runs from `exp(-3)` to `exp(3)` with the Gauss-Hermite limit
  at 1, so both the Gauss-Hermite and Monahan-Stefanski branches are used.
  """
  loc_axis = tf.linspace(-10.0, 10.0, 10)
  scale_axis = tf.exp(tf.linspace(-3.0, 3.0, 10))
  locs, scales = tf.meshgrid(loc_axis, scale_axis)
  dist = tfd.LogitNormal(
      loc=locs,
      scale=scales,
      validate_args=True,
      gauss_hermite_scale_limit=1.,
      num_probit_terms_approx=6)

  approx_means = dist.mean_approx()
  expected_means = logit_normal_mean_trapezoid(locs, scales)
  self.assertAllClose(expected_means, approx_means, rtol=1e-4)

  approx_variances = dist.variance_approx()
  expected_variances = logit_normal_variance_trapezoid(locs, scales)
  self.assertAllClose(expected_variances, approx_variances, rtol=1e-4)
135+
72136
def testLogitNormalLogitNormalKL(self):
73137
batch_size = 6
74138
mu_a = np.array([3.0] * batch_size)

0 commit comments

Comments
 (0)