 from __future__ import division
 from __future__ import print_function

+import numpy as onp
+
 import tensorflow.compat.v2 as tf

 from tensorflow_probability.python import math as tfp_math
@@ -47,6 +49,8 @@ def __init__(self,
                loc,
                scale,
                num_probit_terms_approx=2,
+               gauss_hermite_scale_limit=None,
+               gauss_hermite_degree=20,
                validate_args=False,
                allow_nan_stats=True,
                name='LogitNormal'):
@@ -71,6 +75,18 @@ def __init__(self,
         (inclusive). Using `num_probit_terms_approx=2` should result in
         `mean_approx` error not exceeding `10**-4`.
         Default value: `2`.
+      gauss_hermite_scale_limit: Floating-point `Tensor` or `None`.
+        The (batch-wise) maximum scale at which to compute statistics
+        with Gauss-Hermite quadrature instead of the Monahan-Stefanski
+        approximation [1]. Default: `None`, which recovers the legacy
+        behavior of using Monahan-Stefanski everywhere and does not
+        add TF ops for Gauss-Hermite. The best value depends on the
+        working precision and the number of terms in the Gauss-Hermite
+        or Monahan-Stefanski approximations being switched between,
+        as well as the expected range of `loc` parameters; but `1` is
+        not unreasonable.
+      gauss_hermite_degree: Python integer giving the number of
+        sample points to use for Gauss-Hermite quadrature.
       validate_args: Python `bool`, default `False`. Whether to validate input
         with asserts. If `validate_args` is `False`, and the inputs are
         invalid, correct behavior is not guaranteed.
@@ -90,6 +106,7 @@ def __init__(self,
          Communications in Statistics-Simulation and Computation 9.4 (1980):
          389-419.
          https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
+
     """
     parameters = dict(locals())
     num_probit_terms_approx = int(num_probit_terms_approx)
@@ -98,6 +115,8 @@ def __init__(self,
           'Argument `num_probit_terms_approx` must be an integer between '
           '`1` and `8` (inclusive).')
     self._num_probit_terms_approx = num_probit_terms_approx
+    self._gauss_hermite_scale_limit = gauss_hermite_scale_limit
+    self._gauss_hermite_degree = gauss_hermite_degree
     with tf.name_scope(name) as name:
       super(LogitNormal, self).__init__(
           distribution=normal_lib.Normal(loc=loc, scale=scale),
@@ -131,6 +150,16 @@ def num_probit_terms_approx(self):
     """Number of `Normal(0, 1).cdf` terms using in `mean_*_approx` functions."""
     return self._num_probit_terms_approx

+  @property
+  def gauss_hermite_scale_limit(self):
+    """Largest scale using Gauss-Hermite quadrature in `*_approx` functions."""
+    return self._gauss_hermite_scale_limit
+
+  @property
+  def gauss_hermite_degree(self):
+    """Number of points for Gauss-Hermite quadrature in `*_approx` functions."""
+    return self._gauss_hermite_degree
+
   experimental_is_sharded = False

   def mean_log_prob_approx(self, y=None, name='mean_log_prob_approx'):
@@ -199,10 +228,19 @@ def mean_approx(self, name='mean_approx'):
          https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
     """
     with self._name_and_control_scope(name):
-      return approx_expected_sigmoid(
-          self.loc, self.scale,
+      loc = tf.convert_to_tensor(self.loc)
+      scale = tf.convert_to_tensor(self.scale)
+      monahan_stefanski_answer = approx_expected_sigmoid(
+          loc, scale,
           MONAHAN_MIX_PROB[self.num_probit_terms_approx],
           MONAHAN_INVERSE_SCALE[self.num_probit_terms_approx])
+      if self.gauss_hermite_scale_limit is None:
+        return monahan_stefanski_answer
+      else:
+        gauss_hermite_answer = logit_normal_mean_gh(
+            loc, scale, self.gauss_hermite_degree)
+        return tf.where(scale < self.gauss_hermite_scale_limit,
+                        gauss_hermite_answer, monahan_stefanski_answer)

   def variance_approx(self, name='variance_approx'):
     """Approximate the variance of a LogitNormal.
@@ -233,10 +271,19 @@ def variance_approx(self, name='variance_approx'):
          https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
     """
     with self._name_and_control_scope(name):
-      return approx_variance_sigmoid(
-          self.loc, self.scale,
+      loc = tf.convert_to_tensor(self.loc)
+      scale = tf.convert_to_tensor(self.scale)
+      monahan_stefanski_answer = approx_variance_sigmoid(
+          loc, scale,
           MONAHAN_MIX_PROB[self.num_probit_terms_approx],
           MONAHAN_INVERSE_SCALE[self.num_probit_terms_approx])
+      if self.gauss_hermite_scale_limit is None:
+        return monahan_stefanski_answer
+      else:
+        gauss_hermite_answer = logit_normal_variance_gh(
+            loc, scale, self.gauss_hermite_degree)
+        return tf.where(scale < self.gauss_hermite_scale_limit,
+                        gauss_hermite_answer, monahan_stefanski_answer)

   def stddev_approx(self, name='stddev_approx'):
     """Approximate the stdandard deviation of a LogitNormal.
@@ -479,3 +526,37 @@ def approx_variance_sigmoid(
       alpha[tf.newaxis, :] * alpha[:, tf.newaxis] * (b + bt),
       axis=[-2, -1])
   return mom2 - approx_expected_sigmoid(m, s, alpha, c)**2.
+
+
+# The above approximations fail for small scales. We compute
+# statistics for small scales with Gauss-Hermite quadrature.
+
+
+def logit_normal_mean_gh(loc, scale, deg):
+  """Approximates `E_{N(m,s)}[sigmoid(X)]` by Gauss-Hermite quadrature."""
+  # We want to integrate
+  #   A = \int_-inf^inf sigmoid(x) * Normal(loc, scale).pdf(x) dx
+  # To bring it into the right form for Gauss-Hermite quadrature,
+  # we make the substitution y = (x - loc) / scale, to get
+  #   A = (1/sqrt(2*pi)) * \int_-inf^inf [
+  #         sigmoid(y * scale + loc) * exp(-1/2 y**2) dy]
+  grid, weights = onp.polynomial.hermite_e.hermegauss(deg)
+  grid = tf.cast(grid, dtype=loc.dtype)
+  weights = tf.cast(weights, dtype=loc.dtype)
+  normalizer = tf.constant(onp.sqrt(2 * onp.pi), dtype=loc.dtype)
+  values = tf.sigmoid(grid * scale[..., tf.newaxis] + loc[..., tf.newaxis])
+  return tf.reduce_sum(values * weights, axis=-1) / normalizer
+
+
+def logit_normal_variance_gh(loc, scale, deg):
+  """Approximates `Var_{N(m,s)}[sigmoid(X)]` by Gauss-Hermite quadrature."""
+  # Since we have to compute sigmoids for variance anyway, we inline
+  # computing the mean by Gauss-Hermite quadrature at the same grid of points.
+  grid, weights = onp.polynomial.hermite_e.hermegauss(deg)
+  grid = tf.cast(grid, dtype=loc.dtype)
+  weights = tf.cast(weights, dtype=loc.dtype)
+  normalizer = tf.constant(onp.sqrt(2 * onp.pi), dtype=loc.dtype)
+  sigmoids = tf.sigmoid(grid * scale[..., tf.newaxis] + loc[..., tf.newaxis])
+  mean = tf.reduce_sum(sigmoids * weights, axis=-1) / normalizer
+  residuals = (sigmoids - mean[..., tf.newaxis])**2
+  return tf.reduce_sum(residuals * weights, axis=-1) / normalizer
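
Usage sketch (not part of the patch; it assumes this change is applied and that the class is reached through `tfp.distributions.LogitNormal`): with a finite `gauss_hermite_scale_limit`, `mean_approx` and `variance_approx` use Gauss-Hermite quadrature for batch members whose `scale` falls below the limit and the Monahan-Stefanski approximation for the rest.

import tensorflow_probability as tfp

# Two batch members: scale 0.05 is below the limit of 1., so its statistics
# come from Gauss-Hermite quadrature; scale 2.0 is above it, so its
# statistics come from the Monahan-Stefanski approximation.
dist = tfp.distributions.LogitNormal(
    loc=[0.3, 0.3], scale=[0.05, 2.0],
    gauss_hermite_scale_limit=1.,
    gauss_hermite_degree=20)
print(dist.mean_approx())
print(dist.variance_approx())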
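
The comments in `logit_normal_mean_gh` reduce `E[sigmoid(X)]`, `X ~ Normal(loc, scale)`, to a probabilists' Gauss-Hermite integral via the substitution `y = (x - loc) / scale`. A minimal standalone NumPy sketch of that identity (not part of the patch), checked against a Monte Carlo estimate:

import numpy as np

loc, scale, deg = 0.3, 0.5, 20

# Gauss-Hermite: E[sigmoid(X)] ~= sum_i w_i * sigmoid(loc + scale * x_i) / sqrt(2*pi),
# where (x_i, w_i) integrate against the weight exp(-y**2 / 2).
grid, weights = np.polynomial.hermite_e.hermegauss(deg)
gh_mean = np.sum(weights / (1. + np.exp(-(loc + scale * grid)))) / np.sqrt(2. * np.pi)

# Monte Carlo reference.
samples = np.random.normal(loc, scale, size=1_000_000)
mc_mean = (1. / (1. + np.exp(-samples))).mean()

print(gh_mean, mc_mean)  # Should agree to roughly three decimal places.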