Skip to content

Commit 71a9264

Browse files
authored
Improve heteroscedastic (#113)
* Replace first by only when needed and divide by 2 instead of multiplying by 0.5 * More transitions from first to only * Fixed the quadrature * Use only for prior tests * Fixing logistic-softmax * Fix introduced issue with sampling * Add sampling method for heteroscedastic gaussian * Fix bug from heteroscedasticity * Fixed formulations and abuse of `map!` * Abuse of map! * Missing coma * Fixingy fixes * Fixed the uses of map! * Added fixes * Beautiful modifications * Finally fixed the PG distributions! * Handle case d.b < 1 * Patch bump
1 parent fdc70d2 commit 71a9264

File tree

16 files changed

+268
-245
lines changed

16 files changed

+268
-245
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AugmentedGaussianProcesses"
22
uuid = "38eea1fd-7d7d-5162-9d08-f89d0f2e271e"
33
authors = ["Theo Galy-Fajou <[email protected]>"]
4-
version = "0.11.2"
4+
version = "0.11.3"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/examples/heteroscedastic.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using AugmentedGaussianProcesses
55
using Distributions
66
using LinearAlgebra
77
using Plots
8+
using Random
89
default(; lw=3.0, msw=0.0)
910
# using CairoMakie
1011

@@ -15,16 +16,17 @@ default(; lw=3.0, msw=0.0)
1516
# ``y \sim f + \epsilon``
1617
# where ``\epsilon \sim \mathcal{N}(0, (\lambda \sigma(g))^{-1})``
1718
# We create a toy dataset with X ∈ [-10, 10] and sample `f`, `g` and `y` given this same generative model
19+
rng = MersenneTwister(42)
1820
N = 200
19-
x = (sort(rand(N)) .- 0.5) * 20.0
21+
x = (sort(rand(rng, N)) .- 0.5) * 20.0
2022
x_test = range(-10, 10; length=500)
2123
kernel = 5.0 * SqExponentialKernel() ScaleTransform(1.0) # Kernel function
2224
K = kernelmatrix(kernel, x) + 1e-5I # The kernel matrix
23-
f = rand(MvNormal(K)); # We draw a random sample from the GP prior
25+
f = rand(rng, MvNormal(K)); # We draw a random sample from the GP prior
2426

2527
# We add a prior mean on `g` so that the variance does not become too big
2628
μ₀ = -3.0
27-
g = rand(MvNormal(μ₀ * ones(N), K))
29+
g = rand(rng, MvNormal(μ₀ * ones(N), K))
2830
λ = 3.0 # The maximum possible precision
2931
σ = inv.(sqrt.(λ * AGP.logistic.(g))) # We use the following transform to obtain the std. deviation
3032
y = f + σ .* randn(N); # We finally sample the ouput
@@ -38,7 +40,7 @@ scatter!(x, y; alpha=0.5, msw=0.0, lab="y") # Observation samples
3840
model = VGP(
3941
x,
4042
y,
41-
deepcopy(kernel),
43+
kernel,
4244
HeteroscedasticLikelihood(λ),
4345
AnalyticVI();
4446
optimiser=true, # We optimise both the mean parameters and kernel hyperparameters

docs/src/userguide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ Not all inference are implemented/valid for all likelihoods, here is the compati
114114
| GaussianLikelihood | ✔ (Analytic) ||||
115115
| StudentTLikelihood |||||
116116
| LaplaceLikelihood |||||
117-
| HeteroscedasticLikelihood || (dev) | (dev) ||
117+
| HeteroscedasticLikelihood || | (dev) ||
118118
| LogisticLikelihood |||||
119119
| BayesianSVM || (dev) |||
120120
| LogisticSoftMaxLikelihood |||| (dev) |

src/ComplementaryDistributions/ComplementaryDistributions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ComplementaryDistributions
33
using Distributions
44
using Random
55
using SpecialFunctions
6-
using StatsFuns: twoπ
6+
using StatsFuns: twoπ, halfπ, inv2π, fourinvπ
77

88
export GeneralizedInverseGaussian, PolyaGamma, LaplaceTransformDistribution
99
include("generalizedinversegaussian.jl")
Lines changed: 106 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
using Distributions, Random
22
using Statistics
33
using SpecialFunctions
4-
const __TRUNC = 0.64;
5-
const __TRUNC_RECIP = 1.0 / __TRUNC;
4+
const pg_t = 0.64
5+
const pg_inv_t = inv(pg_t)
6+
67
"""
7-
PolyaGamma(b::Int, c::Real)
8+
PolyaGamma(b::Real, c::Real)
89
910
## Arguments
10-
- `b::Int`
11+
- `b::Real`
1112
- `c::Real` exponential tilting
1213
1314
## Keyword Arguments
@@ -16,151 +17,161 @@ const __TRUNC_RECIP = 1.0 / __TRUNC;
1617
1718
Create a PolyaGamma sampler with parameters `b` and `c`
1819
"""
19-
struct PolyaGamma{Tc,A} <: Distributions.ContinuousUnivariateDistribution
20+
struct PolyaGamma{Tb,Tc} <: Distributions.ContinuousUnivariateDistribution
2021
# For sum of Gammas.
21-
b::Int
22+
b::Tb
2223
c::Tc
23-
trunc::Int
24-
nmax::Int
25-
bvec::A
26-
#Constructor
27-
function PolyaGamma{T}(b::Int, c::T, trunc::Int, nmax::Int) where {T<:Real}
28-
if trunc < 1
29-
@warn "trunc < 1. Setting trunc=1."
30-
trunc = 1
31-
end
32-
bvec = [convert(T, (twoπ * (k - 0.5))^2) for k in 1:trunc]
33-
return new{typeof(c),typeof(bvec)}(b, c, trunc, nmax, bvec)
34-
end
3524
end
3625

26+
Base.eltype(::PolyaGamma{T,Tc}) where {T,Tc} = Tc
27+
28+
Distributions.params(d::PolyaGamma) = (d.b, d.c)
29+
3730
Statistics.mean(d::PolyaGamma) = d.b / (2 * d.c) * tanh(d.c / 2)
3831

39-
function PolyaGamma(b::Int, c::T; nmax::Int=10, trunc::Int=200) where {T<:Real}
40-
return PolyaGamma{T}(b, c, trunc, nmax)
32+
Base.minimum(d::PolyaGamma) = zero(eltype(d))
33+
Base.maximum(::PolyaGamma) = Inf
34+
Distributions.insupport(::PolyaGamma, x::Real) = zero(x) <= x < Inf
35+
36+
function Distributions.pdf(d::PolyaGamma, x::Real)
37+
b, c = params(d)
38+
iszero(x) && return zero(x)
39+
return _tilt(x, b, c) * 2^(b - 1) / gamma(b) * sum(0:200) do n
40+
ifelse(iseven(n), 1, -1) * exp(
41+
loggamma(n + b) - loggamma(n + 1) + log(2n + b) - log(twoπ * x^3) / 2 -
42+
(2n + b)^2 / (8x),
43+
)
44+
end
4145
end
4246

43-
function Distributions.pdf(d::PolyaGamma, x)
44-
return cosh(d.c / 2)^d.b * 2.0^(d.b - 1) / gamma(d.b) * sum(
45-
((-1)^n) * gamma(n + d.b) / gamma(n + 1) * (2 * n + b) / (sqrt(2 * π * x^3)) *
46-
exp(-(2 * n + b)^2 / (8 * x) - c^2 / 2 * x) for n in 0:(d.nmax)
47-
)
47+
function _tilt(ω, b, c)
48+
return cosh(c / 2)^b * exp(-c^2 / 2 * ω)
4849
end
4950

50-
## Sampling
51-
function Distributions.rand(rng::AbstractRNG, d::PolyaGamma{T}) where {T<:Real}
51+
function Distributions.rand(rng::AbstractRNG, d::PolyaGamma)
5252
if iszero(d.b)
53-
return zero(T)
53+
return zero(eltype(d))
54+
end
55+
return draw_sum(rng, d)
56+
end
57+
58+
## Sampling when `b` is an integer
59+
function draw_sum(rng::AbstractRNG, d::PolyaGamma{<:Int})
60+
return sum(Base.Fix1(sample_pg1, rng), d.c * ones(d.b))
61+
end
62+
63+
function draw_sum(rng::AbstractRNG, d::PolyaGamma{<:Real})
64+
if d.b < 1
65+
return rand_gamma_sum(rng, d, d.b)
5466
end
55-
return sum(Base.Fix1(draw_like_devroye, rng), d.c * ones(d.b))
67+
trunc_b = floor(Int, d.b)
68+
res_b = d.b - trunc_b
69+
trunc_term = sum(Base.Fix1(sample_pg1, rng), d.c * ones(trunc_b))
70+
res_term = rand_gamma_sum(rng, d, res_b)
71+
return trunc_term + res_term
5672
end
5773

5874
## Utility functions
5975
function a(n::Int, x::Real)
6076
k = (n + 0.5) * π
61-
if x > __TRUNC
77+
if x > pg_t
6278
return k * exp(-k^2 * x / 2)
6379
elseif x > 0
64-
expnt = -1.5 * (log(π / 2) + log(x)) + log(k) - 2 * (n + 0.5)^2 / x
80+
expnt = -3 / 2 * (log(halfπ) + log(x)) + log(k) - 2 * (n + 1//2)^2 / x
6581
return exp(expnt)
82+
else
83+
error("x should be a positive real")
6684
end
6785
end
6886

6987
function mass_texpon(z::Real)
70-
t = __TRUNC
88+
t = pg_t
7189

72-
fz = 0.125 * π^2 + z^2 / 2
73-
b = sqrt(1.0 / t) * (t * z - 1)
74-
a = sqrt(1.0 / t) * (t * z + 1) * -1.0
90+
K = π^2 / 8 + z^2 / 2
91+
b = sqrt(inv(t)) * (t * z - 1)
92+
a = -sqrt(inv(t)) * (t * z + 1)
7593

76-
x0 = log(fz) + fz * t
94+
x0 = log(K) + K * t
7795
xb = x0 - z + logcdf(Distributions.Normal(), b)
7896
xa = x0 + z + logcdf(Distributions.Normal(), a)
7997

80-
qdivp = 4 / π * (exp(xb) + exp(xa))
98+
qdivp = fourinvπ * (exp(xb) + exp(xa))
8199

82-
return 1.0 / (1.0 + qdivp)
100+
return 1 / (1 + qdivp)
83101
end
84102

85-
function rtigauss(rng::AbstractRNG, z::Real)
86-
z = abs(z)
87-
t = __TRUNC
88-
x = t + 1.0
89-
if __TRUNC_RECIP > z
90-
alpha = 0.0
91-
rate = 1.0
92-
d_exp = Exponential(1.0 / rate)
93-
while (rand(rng) > alpha)
94-
e1 = rand(rng, d_exp)
95-
e2 = rand(rng, d_exp)
96-
while e1^2 > 2 * e2 / t
97-
e1 = rand(rng, d_exp)
98-
e2 = rand(rng, d_exp)
103+
# Sample from a truncated inverse gaussian
104+
function rand_truncated_inverse_gaussian(rng::AbstractRNG, z::Real)
105+
μ = inv(z)
106+
x = one(z) + pg_t
107+
if μ > pg_t
108+
d_exp = Exponential()
109+
while true
110+
E = rand(rng, d_exp)
111+
E′ = rand(rng, d_exp)
112+
while E^2 > 2E′ / pg_t
113+
E = rand(rng, d_exp)
114+
E′ = rand(rng, d_exp)
99115
end
100-
x = 1 + e1 * t
101-
x = t / x^2
102-
alpha = exp(-z^2 * x / 2)
116+
x = pg_t / (1 + E * pg_t)^2
117+
α = exp(-z^2 * x / 2)
118+
α >= rand(rng) && break
103119
end
104120
else
105-
mu = 1.0 / z
106-
while (x > t)
107-
y = randn(rng)^2
108-
half_mu = mu / 2
109-
mu_Y = mu * y
110-
x = mu + half_mu * mu_Y - half_mu * sqrt(4 * mu_Y + mu_Y^2)
111-
if rand(rng) > mu / (mu + x)
112-
x = mu^2 / x
121+
while (x > pg_t)
122+
Y = randn(rng)^2
123+
μY = μ * Y
124+
x = μ + μ * μY / 2 - μ / 2 * sqrt(4 * μY + μY^2)
125+
if rand(rng) > μ /+ x)
126+
x = μ^2 / x
113127
end
128+
x > pg_t && break
114129
end
115130
end
116131
return x
117132
end
118133

119-
# ////////////////////////////////////////////////////////////////////////////////
120-
# // Sample //
121-
# ////////////////////////////////////////////////////////////////////////////////
122-
123-
function draw_like_devroye(rng::AbstractRNG, c::Real)
134+
# Sample from PG(1, z)
135+
# Algorithm 1 from "Bayesian Inference for logistic models..." p. 26
136+
function sample_pg1(rng::AbstractRNG, z::Real)
124137
# Change the parameter.
125-
c = abs(c) / 2
138+
z = abs(z) / 2
126139

127140
# Now sample 0.25 * J^*(1, Z := Z/2).
128-
fz = 0.125 * π^2 + c^2 / 2
129-
# ... Problems with large Z? Try using q_over_p.
130-
# double p = 0.5 * __PI * exp(-1.0 * fz * __TRUNC) / fz;
131-
# double q = 2 * exp(-1.0 * Z) * pigauss(__TRUNC, Z);
132-
133-
x = 0.0
134-
s = 1.0
135-
y = 0.0
136-
# int iter = 0; If you want to keep track of iterations.
137-
d_exp = Exponential()
141+
K = π^2 / 8 + z^2 / 2
142+
t = pg_t
143+
144+
r = mass_texpon(z)
145+
138146
while true
139-
if rand(rng) < mass_texpon(c)
140-
x = __TRUNC + rand(rng, d_exp) / fz
141-
else
142-
x = rtigauss(rng, c)
147+
if r > rand(rng) # sample from truncated exponential
148+
x = t + rand(rng, Exponential()) / K
149+
else # sample from truncated inverse Gaussian
150+
x = rand_truncated_inverse_gaussian(rng,z)
143151
end
144152
s = a(0, x)
145153
y = rand(rng) * s
146154
n = 0
147-
go = true
148-
149-
# Cap the number of iterations?
150-
while (go)
155+
while true
151156
n = n + 1
152157
if isodd(n)
153158
s = s - a(n, x)
154-
if y <= s
155-
return 0.25 * x
156-
end
159+
y <= s && return x / 4
157160
else
158161
s = s + a(n, x)
159-
if y > s
160-
go = false
161-
end
162+
y > s && break
162163
end
163164
end
164-
# Need Y <= S in event that Y = S, e.g. when x = 0.
165165
end
166-
end # draw_like_devroye
166+
end # Sample PG(1, c)
167+
168+
# Sample ω as the series of Gamma variables (truncated at 200)
169+
function rand_gamma_sum(rng::AbstractRNG, d::PolyaGamma, e::Real)
170+
C = inv2π / π
171+
c = d.c
172+
w = (c * inv2π)^2
173+
d = Gamma(e, 1)
174+
return C * sum(1:200) do k
175+
rand(rng, d) / ((k - 0.5)^2 + w)
176+
end
177+
end

src/functions/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ function expectation(f, μ::Real, σ²::Real)
1818
return dot(pred_weights, f.(x))
1919
end
2020

21+
# Return √E[f^2]
22+
function sqrt_expec_square(μ, σ²)
23+
return sqrt(abs2(μ) + σ²)
24+
end
25+
26+
# Return √E[(f-y)^2]
27+
function sqrt_expec_square(μ, σ², y)
28+
return sqrt(abs2- y) + σ²)
29+
end
30+
2131
## delta function `(i,j)`, equal `1` if `i == j`, `0` else ##
2232
@inline function δ(T, i::Integer, j::Integer)
2333
return ifelse(i == j, one(T), zero(T))

src/likelihood/bayesiansvm.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ function local_updates!(
4444
μ::AbstractVector{T},
4545
diagΣ::AbstractVector,
4646
) where {T}
47-
@. local_vars.c = abs2(one(T) - y * μ) + diagΣ
48-
@. local_vars.θ = inv(sqrt(local_vars.c))
47+
map!(local_vars.c, μ, diagΣ, y) do μ, σ², y
48+
abs2(one(T) - y * μ) + σ²
49+
end
50+
map!(inv sqrt, local_vars.θ, local_vars.c)
4951
return local_vars
5052
end
5153

0 commit comments

Comments
 (0)