11using Distributions, Random
22using Statistics
33using SpecialFunctions
4- const __TRUNC = 0.64 ;
5- const __TRUNC_RECIP = 1.0 / __TRUNC;
4+ const pg_t = 0.64
5+ const pg_inv_t = inv (pg_t)
6+
67"""
7- PolyaGamma(b::Int , c::Real)
8+ PolyaGamma(b::Real , c::Real)
89
910## Arguments
10- - `b::Int `
11+ - `b::Real `
1112- `c::Real` exponential tilting
1213
1314## Keyword Arguments
@@ -16,151 +17,161 @@ const __TRUNC_RECIP = 1.0 / __TRUNC;
1617
1718Create a PolyaGamma sampler with parameters `b` and `c`
1819"""
19- struct PolyaGamma{Tc,A } <: Distributions.ContinuousUnivariateDistribution
20+ struct PolyaGamma{Tb,Tc } <: Distributions.ContinuousUnivariateDistribution
2021 # For sum of Gammas.
21- b:: Int
22+ b:: Tb
2223 c:: Tc
23- trunc:: Int
24- nmax:: Int
25- bvec:: A
26- # Constructor
27- function PolyaGamma {T} (b:: Int , c:: T , trunc:: Int , nmax:: Int ) where {T<: Real }
28- if trunc < 1
29- @warn " trunc < 1. Setting trunc=1."
30- trunc = 1
31- end
32- bvec = [convert (T, (twoπ * (k - 0.5 ))^ 2 ) for k in 1 : trunc]
33- return new {typeof(c),typeof(bvec)} (b, c, trunc, nmax, bvec)
34- end
3524end
3625
26+ Base. eltype (:: PolyaGamma{T,Tc} ) where {T,Tc} = Tc
27+
28+ Distributions. params (d:: PolyaGamma ) = (d. b, d. c)
29+
3730Statistics. mean (d:: PolyaGamma ) = d. b / (2 * d. c) * tanh (d. c / 2 )
3831
39- function PolyaGamma (b:: Int , c:: T ; nmax:: Int = 10 , trunc:: Int = 200 ) where {T<: Real }
40- return PolyaGamma {T} (b, c, trunc, nmax)
32+ Base. minimum (d:: PolyaGamma ) = zero (eltype (d))
33+ Base. maximum (:: PolyaGamma ) = Inf
34+ Distributions. insupport (:: PolyaGamma , x:: Real ) = zero (x) <= x < Inf
35+
36+ function Distributions. pdf (d:: PolyaGamma , x:: Real )
37+ b, c = params (d)
38+ iszero (x) && return zero (x)
39+ return _tilt (x, b, c) * 2 ^ (b - 1 ) / gamma (b) * sum (0 : 200 ) do n
40+ ifelse (iseven (n), 1 , - 1 ) * exp (
41+ loggamma (n + b) - loggamma (n + 1 ) + log (2 n + b) - log (twoπ * x^ 3 ) / 2 -
42+ (2 n + b)^ 2 / (8 x),
43+ )
44+ end
4145end
4246
43- function Distributions. pdf (d:: PolyaGamma , x)
44- return cosh (d. c / 2 )^ d. b * 2.0 ^ (d. b - 1 ) / gamma (d. b) * sum (
45- ((- 1 )^ n) * gamma (n + d. b) / gamma (n + 1 ) * (2 * n + b) / (sqrt (2 * π * x^ 3 )) *
46- exp (- (2 * n + b)^ 2 / (8 * x) - c^ 2 / 2 * x) for n in 0 : (d. nmax)
47- )
47+ function _tilt (ω, b, c)
48+ return cosh (c / 2 )^ b * exp (- c^ 2 / 2 * ω)
4849end
4950
50- # # Sampling
51- function Distributions. rand (rng:: AbstractRNG , d:: PolyaGamma{T} ) where {T<: Real }
51+ function Distributions. rand (rng:: AbstractRNG , d:: PolyaGamma )
5252 if iszero (d. b)
53- return zero (T)
53+ return zero (eltype (d))
54+ end
55+ return draw_sum (rng, d)
56+ end
57+
58+ # # Sampling when `b` is an integer
59+ function draw_sum (rng:: AbstractRNG , d:: PolyaGamma{<:Int} )
60+ return sum (Base. Fix1 (sample_pg1, rng), d. c * ones (d. b))
61+ end
62+
63+ function draw_sum (rng:: AbstractRNG , d:: PolyaGamma{<:Real} )
64+ if d. b < 1
65+ return rand_gamma_sum (rng, d, d. b)
5466 end
55- return sum (Base. Fix1 (draw_like_devroye, rng), d. c * ones (d. b))
67+ trunc_b = floor (Int, d. b)
68+ res_b = d. b - trunc_b
69+ trunc_term = sum (Base. Fix1 (sample_pg1, rng), d. c * ones (trunc_b))
70+ res_term = rand_gamma_sum (rng, d, res_b)
71+ return trunc_term + res_term
5672end
5773
5874# # Utility functions
5975function a (n:: Int , x:: Real )
6076 k = (n + 0.5 ) * π
61- if x > __TRUNC
77+ if x > pg_t
6278 return k * exp (- k^ 2 * x / 2 )
6379 elseif x > 0
64- expnt = - 1.5 * (log (π / 2 ) + log (x)) + log (k) - 2 * (n + 0.5 )^ 2 / x
80+ expnt = - 3 / 2 * (log (halfπ ) + log (x)) + log (k) - 2 * (n + 1 // 2 )^ 2 / x
6581 return exp (expnt)
82+ else
83+ error (" x should be a positive real" )
6684 end
6785end
6886
6987function mass_texpon (z:: Real )
70- t = __TRUNC
88+ t = pg_t
7189
72- fz = 0.125 * π^ 2 + z^ 2 / 2
73- b = sqrt (1.0 / t ) * (t * z - 1 )
74- a = sqrt (1.0 / t) * (t * z + 1 ) * - 1.0
90+ K = π^ 2 / 8 + z^ 2 / 2
91+ b = sqrt (inv (t) ) * (t * z - 1 )
92+ a = - sqrt (inv (t)) * (t * z + 1 )
7593
76- x0 = log (fz ) + fz * t
94+ x0 = log (K ) + K * t
7795 xb = x0 - z + logcdf (Distributions. Normal (), b)
7896 xa = x0 + z + logcdf (Distributions. Normal (), a)
7997
80- qdivp = 4 / π * (exp (xb) + exp (xa))
98+ qdivp = fourinvπ * (exp (xb) + exp (xa))
8199
82- return 1.0 / (1.0 + qdivp)
100+ return 1 / (1 + qdivp)
83101end
84102
85- function rtigauss (rng:: AbstractRNG , z:: Real )
86- z = abs (z)
87- t = __TRUNC
88- x = t + 1.0
89- if __TRUNC_RECIP > z
90- alpha = 0.0
91- rate = 1.0
92- d_exp = Exponential (1.0 / rate)
93- while (rand (rng) > alpha)
94- e1 = rand (rng, d_exp)
95- e2 = rand (rng, d_exp)
96- while e1^ 2 > 2 * e2 / t
97- e1 = rand (rng, d_exp)
98- e2 = rand (rng, d_exp)
103+ # Sample from a truncated inverse gaussian
104+ function rand_truncated_inverse_gaussian (rng:: AbstractRNG , z:: Real )
105+ μ = inv (z)
106+ x = one (z) + pg_t
107+ if μ > pg_t
108+ d_exp = Exponential ()
109+ while true
110+ E = rand (rng, d_exp)
111+ E′ = rand (rng, d_exp)
112+ while E^ 2 > 2 E′ / pg_t
113+ E = rand (rng, d_exp)
114+ E′ = rand (rng, d_exp)
99115 end
100- x = 1 + e1 * t
101- x = t / x ^ 2
102- alpha = exp ( - z ^ 2 * x / 2 )
116+ x = pg_t / ( 1 + E * pg_t) ^ 2
117+ α = exp ( - z ^ 2 * x / 2 )
118+ α >= rand (rng) && break
103119 end
104120 else
105- mu = 1.0 / z
106- while (x > t)
107- y = randn (rng)^ 2
108- half_mu = mu / 2
109- mu_Y = mu * y
110- x = mu + half_mu * mu_Y - half_mu * sqrt (4 * mu_Y + mu_Y^ 2 )
111- if rand (rng) > mu / (mu + x)
112- x = mu^ 2 / x
121+ while (x > pg_t)
122+ Y = randn (rng)^ 2
123+ μY = μ * Y
124+ x = μ + μ * μY / 2 - μ / 2 * sqrt (4 * μY + μY^ 2 )
125+ if rand (rng) > μ / (μ + x)
126+ x = μ^ 2 / x
113127 end
128+ x > pg_t && break
114129 end
115130 end
116131 return x
117132end
118133
119- # ////////////////////////////////////////////////////////////////////////////////
120- # // Sample //
121- # ////////////////////////////////////////////////////////////////////////////////
122-
123- function draw_like_devroye (rng:: AbstractRNG , c:: Real )
134+ # Sample from PG(1, z)
135+ # Algorithm 1 from "Bayesian Inference for logistic models..." p. 26
136+ function sample_pg1 (rng:: AbstractRNG , z:: Real )
124137 # Change the parameter.
125- c = abs (c ) / 2
138+ z = abs (z ) / 2
126139
127140 # Now sample 0.25 * J^*(1, Z := Z/2).
128- fz = 0.125 * π^ 2 + c^ 2 / 2
129- # ... Problems with large Z? Try using q_over_p.
130- # double p = 0.5 * __PI * exp(-1.0 * fz * __TRUNC) / fz;
131- # double q = 2 * exp(-1.0 * Z) * pigauss(__TRUNC, Z);
132-
133- x = 0.0
134- s = 1.0
135- y = 0.0
136- # int iter = 0; If you want to keep track of iterations.
137- d_exp = Exponential ()
141+ K = π^ 2 / 8 + z^ 2 / 2
142+ t = pg_t
143+
144+ r = mass_texpon (z)
145+
138146 while true
139- if rand (rng) < mass_texpon (c)
140- x = __TRUNC + rand (rng, d_exp) / fz
141- else
142- x = rtigauss (rng, c )
147+ if r > rand (rng) # sample from truncated exponential
148+ x = t + rand (rng, Exponential ()) / K
149+ else # sample from truncated inverse Gaussian
150+ x = rand_truncated_inverse_gaussian (rng,z )
143151 end
144152 s = a (0 , x)
145153 y = rand (rng) * s
146154 n = 0
147- go = true
148-
149- # Cap the number of iterations?
150- while (go)
155+ while true
151156 n = n + 1
152157 if isodd (n)
153158 s = s - a (n, x)
154- if y <= s
155- return 0.25 * x
156- end
159+ y <= s && return x / 4
157160 else
158161 s = s + a (n, x)
159- if y > s
160- go = false
161- end
162+ y > s && break
162163 end
163164 end
164- # Need Y <= S in event that Y = S, e.g. when x = 0.
165165 end
166- end # draw_like_devroye
166+ end # Sample PG(1, c)
167+
168+ # Sample ω as the series of Gamma variables (truncated at 200)
169+ function rand_gamma_sum (rng:: AbstractRNG , d:: PolyaGamma , e:: Real )
170+ C = inv2π / π
171+ c = d. c
172+ w = (c * inv2π)^ 2
173+ d = Gamma (e, 1 )
174+ return C * sum (1 : 200 ) do k
175+ rand (rng, d) / ((k - 0.5 )^ 2 + w)
176+ end
177+ end
0 commit comments