
Commit eb5eb7f

Add better testing and update on multi-output GPs (#109)
* Fixing issues with multi-output
* Passing tests
* Fix test data/utils
* Use Plots.jl
1 parent e0411a8 commit eb5eb7f

15 files changed: +147 -157 lines changed


docs/examples/heteroscedastic.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ model = VGP(
4141
deepcopy(kernel),
4242
HeteroscedasticLikelihood(λ),
4343
AnalyticVI();
44-
optimiser = true, # We optimise both the mean parameters and kernel hyperparameters
45-
mean = μ₀,
46-
verbose = 1
44+
optimiser=true, # We optimise both the mean parameters and kernel hyperparameters
45+
mean=μ₀,
46+
verbose=1,
4747
)
4848

4949
# Model training, we train for around 100 iterations to wait for the convergence of the hyperparameters
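
For reference, a minimal self-contained sketch of what the full `VGP` call looks like in this keyword style; the toy data, `λ`, and `μ₀` below are placeholders of my own, not the values used in the example file:

using AugmentedGaussianProcesses, KernelFunctions

X = collect(range(-3, 3; length=50))   # placeholder inputs
y = randn(50)                          # placeholder targets
kernel = SqExponentialKernel()
λ = 3.0                                # placeholder likelihood parameter
μ₀ = 0.0                               # placeholder constant prior mean

model = VGP(
    X,
    y,
    deepcopy(kernel),
    HeteroscedasticLikelihood(λ),
    AnalyticVI();
    optimiser=true, # optimise both the mean parameters and kernel hyperparameters
    mean=μ₀,
    verbose=1,
)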

src/data/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function view_y(l::AbstractLikelihood, d::MODataContainer, i::AbstractVector)
77
return view_y.(l, output(d), Ref(i))
88
end
99
function view_y(
10-
l::AbstractVector{<:AbstractLikelihood}, d::MODataContainer, i::AbstractVector
10+
l::Tuple{Vararg{<:AbstractLikelihood}}, d::MODataContainer, i::AbstractVector
1111
)
1212
return view_y.(l, output(d), Ref(i))
1313
end
@@ -22,7 +22,7 @@ function wrap_data(X, y, likelihood::AbstractLikelihood)
2222
return wrap_data(X, y)
2323
end
2424

25-
function wrap_data(X, y, likelihoods::AbstractVector{<:AbstractLikelihood})
25+
function wrap_data(X, y, likelihoods::Tuple{Vararg{<:AbstractLikelihood}})
2626
ys = map(check_data!, y, likelihoods)
2727
return wrap_modata(X, ys)
2828
end
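
The method bodies are unchanged because broadcasting treats tuples and vectors the same way; a tiny plain-Julia illustration of that point (no AGP types involved):

f(a, b) = a + b
ls = (1, 2, 3)        # a Tuple standing in for the per-output likelihoods
ys = (10, 20, 30)     # standing in for output(d)
f.(ls, ys)            # broadcasting over Tuples returns a Tuple: (11, 22, 33)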

src/hyperparameter/autotuning_utils.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,16 @@ function update_kernel!(opt, x::AbstractArray, g::AbstractArray, state)
6767
end
6868

6969
## Updating inducing points
70-
# function update_Z!(opt, Z::Union{ColVecs,RowVecs}, Z_grads::NamedTuple, state)
71-
# return Z.X .+= Optimise.apply!(opt, Z.X, Z_grads.X)
72-
# end
73-
7470
function update_Z!(opt, Z::AbstractVector, Z_grads::AbstractVector, state)
7571
return map(Z, Z_grads, state) do z, zgrad, st
7672
st, ΔZ = Optimisers.apply(opt, st, z, zgrad)
7773
z .+= ΔZ
7874
return st
7975
end
8076
end
77+
78+
function update_Z!(opt, Z::Union{ColVecs,RowVecs}, Z_grads::NamedTuple, state)
79+
st, Δ = Optimisers.apply(opt, state, Z.X, Z_grads.X)
80+
Z.X .+= Δ
81+
return st
82+
end
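
The new `ColVecs`/`RowVecs` method applies one optimiser step directly to the underlying matrix `Z.X`. Below is a sketch of the same kind of single optimiser step on a plain matrix, using the documented Optimisers.jl `setup`/`update!` API; the package code instead calls the lower-level `Optimisers.apply` and manages the sign convention and state itself:

using Optimisers

Z = rand(3, 10)            # stand-in for the inducing-point matrix Z.X
Z_grad = randn(3, 10)      # stand-in for the gradient Z_grads.X
state = Optimisers.setup(Adam(0.001), Z)
state, Z = Optimisers.update!(state, Z, Z_grad)   # one gradient step on Z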

src/inference/analyticVI.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ end
9595
mean_f(m, state.kernel_matrices),
9696
var_f(m, state.kernel_matrices),
9797
) # Compute the local updates given the expectations of f
98-
state = merge(state, (;local_vars))
98+
state = merge(state, (; local_vars))
9999
natural_gradient!.(
100100
m.f,
101101
∇E_μ(m, y, state),

src/likelihood/heteroscedastic.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ function local_updates!(
7878
0.5 * l.invlink.λ[1] * local_vars.ϕ * safe_expcosh(-0.5 * μ[2], 0.5 * local_vars.c)
7979
@. local_vars.θ = 0.5 * (0.5 + local_vars.γ) / local_vars.c * tanh(0.5 * local_vars.c)
8080
@. local_vars.σg = expectation(logistic, μ[2], diagΣ[2])
81-
l.invlink.λ .= max(0.5 * length(local_vars.ϕ) / dot(local_vars.ϕ, local_vars.σg), l.invlink.λ[1])
81+
l.invlink.λ .= max(
82+
0.5 * length(local_vars.ϕ) / dot(local_vars.ϕ, local_vars.σg), l.invlink.λ[1]
83+
)
8284
return local_vars
8385
end
8486

@@ -130,7 +132,9 @@ function heteroscedastic_expectations!(
130132
Σ::AbstractVector,
131133
)
132134
@. local_vars.σg = expectation(logistic, μ, Σ)
133-
l.invlink.λ .= max(0.5 * length(local_vars.ϕ) / dot(local_vars.ϕ, local_vars.σg), l.invlink.λ[1])
135+
l.invlink.λ .= max(
136+
0.5 * length(local_vars.ϕ) / dot(local_vars.ϕ, local_vars.σg), l.invlink.λ[1]
137+
)
134138
return local_vars
135139
end
136140

@@ -153,11 +157,16 @@ end
153157
end
154158

155159
function compute_proba(
156-
l::HeteroscedasticGaussianLikelihood, μs::Tuple{<:AbstractVector,<:AbstractVector}, σs::Tuple{<:AbstractVector,<:AbstractVector}) where {T<:Real}
160+
l::HeteroscedasticGaussianLikelihood,
161+
μs::Tuple{<:AbstractVector,<:AbstractVector},
162+
σs::Tuple{<:AbstractVector,<:AbstractVector},
163+
) where {T<:Real}
157164
return μs[1], σs[1] + expectation.(Ref(l.invlink), μs[2], σs[2])
158165
end
159166

160-
function predict_y(::HeteroscedasticGaussianLikelihood, μs::Tuple{<:AbstractVector,<:AbstractVector})
167+
function predict_y(
168+
::HeteroscedasticGaussianLikelihood, μs::Tuple{<:AbstractVector,<:AbstractVector}
169+
)
161170
return first(μs) # For predict_y the variance is ignored
162171
end
163172
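
The reformatting above does not change the update itself: λ is only ever increased, to max(0.5 * N / ⟨ϕ, σg⟩, current λ). A standalone plain-Julia sketch of that update with stand-in arrays (no AGP types):

using LinearAlgebra

ϕ = rand(100)    # stand-in for local_vars.ϕ
σg = rand(100)   # stand-in for local_vars.σg
λ = [2.0]        # stored as a one-element vector, like l.invlink.λ
λ .= max(0.5 * length(ϕ) / dot(ϕ, σg), λ[1])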

src/likelihood/poisson.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,25 @@ end
8686

8787
### Global Updates ###
8888

89-
@inline function ∇E_μ(::PoissonLikelihood{<:ScaledLogistic}, ::AOptimizer, y::AbstractVector, state)
89+
@inline function ∇E_μ(
90+
::PoissonLikelihood{<:ScaledLogistic}, ::AOptimizer, y::AbstractVector, state
91+
)
9092
return (0.5 * (y - state.γ),)
9193
end
92-
@inline ∇E_Σ(::PoissonLikelihood{<:ScaledLogistic}, ::AOptimizer, y::AbstractVector, state) = (0.5 * state.θ,)
94+
@inline function ∇E_Σ(
95+
::PoissonLikelihood{<:ScaledLogistic}, ::AOptimizer, y::AbstractVector, state
96+
)
97+
return (0.5 * state.θ,)
98+
end
9399

94100
## ELBO Section ##
95101
function expec_loglikelihood(
96-
l::PoissonLikelihood{<:ScaledLogistic}, ::AnalyticVI, y, μ::AbstractVector, Σ::AbstractVector, state
102+
l::PoissonLikelihood{<:ScaledLogistic},
103+
::AnalyticVI,
104+
y,
105+
μ::AbstractVector,
106+
Σ::AbstractVector,
107+
state,
97108
)
98109
tot = 0.5 * (dot(μ, (y - state.γ)) - dot(state.θ, abs2.(μ)) - dot(state.θ, Σ))
99110
tot += Zygote.@ignore(
@@ -106,7 +117,9 @@ function AugmentedKL(l::PoissonLikelihood{<:ScaledLogistic}, state, y)
106117
return PoissonKL(l, state) + PolyaGammaKL(l, state, y)
107118
end
108119

109-
PoissonKL(l::PoissonLikelihood{<:ScaledLogistic}, state) = PoissonKL(state.γ, l.invlink.λ[1])
120+
function PoissonKL(l::PoissonLikelihood{<:ScaledLogistic}, state)
121+
return PoissonKL(state.γ, l.invlink.λ[1])
122+
end
110123

111124
function PolyaGammaKL(::PoissonLikelihood{<:ScaledLogistic}, state, y)
112125
return PolyaGammaKL(y + state.γ, state.c, state.θ)

src/likelihood/regression.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ include("heteroscedastic.jl")
77
include("matern.jl")
88

99
### Return the labels in a vector of vectors for multiple outputs
10-
function treat_labels!(y::AbstractVector{T}, ::Union{RegressionLikelihood,HeteroscedasticGaussianLikelihood}) where {T}
10+
function treat_labels!(
11+
y::AbstractVector{T}, ::Union{RegressionLikelihood,HeteroscedasticGaussianLikelihood}
12+
) where {T}
1113
T <: Real || throw(ArgumentError("For regression target(s) should be real valued"))
1214
return y
1315
end

src/mean/affinemean.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ end
3535

3636
function (μ₀::AffineMean{T})(x::AbstractVector) where {T<:Real}
3737
# μ₀.nDim == size(x, 1) || error(
38-
# "Number of dimensions of prior weight W (",
39-
# size(μ₀.w),
40-
# ") and X (",
41-
# size(x),
42-
# ") do not match",
38+
# "Number of dimensions of prior weight W (",
39+
# size(μ₀.w),
40+
# ") and X (",
41+
# size(x),
42+
# ") do not match",
4343
# )
4444
return dot.(x, Ref(μ₀.w)) .+ first(μ₀.b)
4545
end

src/models/MOSVGP.jl

Lines changed: 40 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,15 @@
44
Multi-Output Sparse Variational Gaussian Process
55
66
## Arguments
7-
- `X::AbstractArray` : : Input features, if `X` is a matrix the choice of colwise/rowwise is given by the `obsdim` keyword
8-
- `y::AbstractVector{<:AbstractVector}` : Output labels, each vector corresponds to one output dimension
97
- `kernel::Union{Kernel,AbstractVector{<:Kernel}` : covariance function or vector of covariance functions, can be either a single kernel or a collection of kernels for multiclass and multi-outputs models
10-
- `likelihood::Union{AbstractLikelihood,Vector{<:Likelihood}` : Likelihood or vector of likelihoods of the model. For compatibilities, see [`Likelihood Types`](@ref likelihood_user)
8+
- `likelihoods::Union{AbstractLikelihood,Vector{<:Likelihood}` : Likelihood or vector of likelihoods of the model. For compatibilities, see [`Likelihood Types`](@ref likelihood_user)
119
- `inference` : Inference for the model, for compatibilities see the [`Compatibility Table`](@ref compat_table))
1210
- `nLatent::Int` : Number of latent GPs
1311
- `nInducingPoints` : number of inducing points, or collection of inducing points locations
1412
1513
## Keyword arguments
1614
- `verbose::Int` : How much does the model print (0:nothing, 1:very basic, 2:medium, 3:everything)
17-
- `optimiser` : Optimiser used for the kernel parameters. Should be an Optimiser object from the [Flux.jl](https://github.com/FluxML/Flux.jl) library, see list here [Optimisers](https://fluxml.ai/Flux.jl/stable/training/optimisers/) and on [this list](https://github.com/theogf/AugmentedGaussianProcesses.jl/tree/master/src/inference/optimisers.jl). Default is `ADAM(0.001)`
15+
- `optimiser` : Optimiser used for the kernel parameters. Should be an Optimiser object from the [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) library. Default is `ADAM(0.001)`
1816
- `Zoptimiser` : Optimiser used for the inducing points locations
1917
- `Aoptimiser` : Optimiser used for the mixing parameters.
2018
- `atfrequency::Int=1` : Choose how many variational parameters iterations are between hyperparameters optimization
@@ -23,64 +21,46 @@ Multi-Output Sparse Variational Gaussian Process
2321
"""
2422
mutable struct MOSVGP{
2523
T<:Real,
26-
TLikelihood<:AbstractLikelihood,
24+
TLikelihood,
2725
TInference<:AbstractInference,
28-
TData<:AbstractDataContainer,
29-
N,
30-
Q,
31-
} <: AbstractGPModel{T,TLikelihood,TInference,N}
32-
data::TData
33-
nFeatures::Vector{Int64} # Number of features of the GP (equal to number of points)
34-
nf_per_task::Vector{Int64}
26+
N, # Number of tasks
27+
Q, # Number of latent GPs
28+
} <: AbstractGPModel{T,AbstractLikelihood,TInference,N}
29+
nf_per_task::NTuple{N,Int}
3530
f::NTuple{Q,SparseVarLatent}
36-
likelihood::Vector{TLikelihood}
31+
likelihood::TLikelihood
3732
inference::TInference
3833
A::Vector{Vector{Vector{T}}}
3934
A_opt::Any
40-
verbose::Int64
41-
atfrequency::Int64
35+
verbose::Int
36+
atfrequency::Int
4237
trained::Bool
4338
end
4439

4540
function MOSVGP(
46-
X::AbstractArray,
47-
y::AbstractVector{<:AbstractVector},
4841
kernel::Union{Kernel,AbstractVector{<:Kernel}},
49-
likelihood::Union{AbstractLikelihood,AbstractVector{<:AbstractLikelihood}},
42+
likelihoods::Union{
43+
AbstractVector{<:AbstractLikelihood},Tuple{Vararg{<:AbstractLikelihood}}
44+
},
5045
inference::AbstractInference,
51-
nLatent::Int,
52-
nInducingPoints::Union{Int,AbstractVector};
46+
Zs::AbstractVector;
5347
verbose::Int=0,
5448
atfrequency::Int=1,
5549
mean::Union{<:Real,AbstractVector{<:Real},PriorMean}=ZeroMean(),
56-
variance::Real=1.0,
5750
optimiser=ADAM(0.01),
5851
Aoptimiser=ADAM(0.01),
5952
Zoptimiser=false,
60-
obsdim::Int=1,
53+
T::DataType=Float64,
6154
)
62-
@assert length(y) > 0 "y should not be an empty vector"
63-
nTask = length(y)
55+
likelihoods = likelihoods isa AbstractVector ? tuple(likelihoods...) : likelihoods
6456

65-
X, T = wrap_X(X, obsdim)
66-
67-
likelihoods = if likelihood isa AbstractLikelihood
68-
likelihoods = [deepcopy(likelihood) for _ in 1:nTask]
69-
else
70-
likelihood
71-
end
72-
73-
nf_per_task = zeros(Int64, nTask)
74-
corrected_y = Vector(undef, nTask)
75-
for i in 1:nTask
76-
corrected_y[i], nf_per_task[i], likelihoods[i] = check_data!(y[i], likelihoods[i])
77-
end
57+
n_task = length(likelihoods)
58+
nf_per_task = n_latent.(likelihoods)
7859

7960
inference isa AnalyticVI || error("The inference object should be of type `AnalyticVI`")
80-
all(implemented.(likelihood, Ref(inference))) ||
81-
error("The $likelihood is not compatible or implemented with the $inference")
82-
83-
data = wrap_data(X, corrected_y)
61+
all(implemented.(likelihoods, Ref(inference))) || error(
62+
"One (or more) of the likelihoods $likelihoods are not compatible or implemented with the $inference",
63+
)
8464

8565
if mean isa Real
8666
mean = ConstantMean(mean)
@@ -92,74 +72,37 @@ function MOSVGP(
9272
optimiser = optimiser ? ADAM(0.01) : nothing
9373
end
9474

75+
if isa(Zoptimiser, Bool)
76+
Zoptimiser = Zoptimiser ? ADAM(0.001) : nothing
77+
end
78+
9579
if isa(Aoptimiser, Bool)
9680
Aoptimiser = Aoptimiser ? ADAM(0.01) : nothing
9781
end
9882

9983
kernel = if kernel isa Kernel
100-
[kernel]
84+
(kernel,)
10185
else
102-
length(kernel) == nLatent ||
86+
length(kernel) == n_task ||
10387
error("Number of kernels should be equal to the number of tasks")
10488
kernel
10589
end
106-
nKernel = length(kernel)
107-
108-
nInducingPoints =
109-
if nInducingPoints isa AbstractVector{<:AbstractVector{<:AbstractVector}}
110-
nInducingPoints
111-
elseif nInducingPoints isa AbstractVector{<:AbstractVector}
112-
[deepcopy(nInducingPoints) for _ in 1:nLatent]
113-
elseif nInducingPoints isa Int
114-
Zref = InducingPoints(KMeansAlg(nInducingPoints), X)
115-
[deepcopy(Zref) for _ in 1:nLatent]
116-
end
117-
118-
nFeatures = size.(Z, 1)
119-
120-
_nMinibatch = nSamples(data)
121-
if is_stochastic(inference)
122-
0 < nMinibatch(inference) < nSamples || error(
123-
"The size of mini-batch $(nMinibatch(inference)) is incorrect (negative or bigger than number of samples), please set nMinibatch correctly in the inference object",
124-
)
125-
_nMinibatch = nMinibatch(inference)
90+
91+
n_kernel = length(kernel)
92+
93+
num_latent = length(Zs)
94+
95+
latent_f = ntuple(num_latent) do i
96+
SparseVarLatent(T, Zs[i], kernel[mod1(i, n_kernel)], mean, optimiser, Zoptimiser)
12697
end
12798

128-
latent_f = ntuple(
129-
i -> _SVGP{T}(
130-
nFeatures[i],
131-
_nMinibatch,
132-
Z[mod(i, nLatent) + 1],
133-
kernel[mod(i, nKernel) + 1],
134-
mean,
135-
optimiser,
136-
Zoptimiser,
137-
),
138-
nLatent,
139-
)
99+
function normalize(x)
100+
return x / sqrt(sum(abs2, x))
101+
end
102+
A = [[normalize(randn(T, num_latent)) for i in 1:nf_per_task[j]] for j in 1:n_task]
140103

141-
A = [
142-
[x -> x / sqrt(sum(abs2, x))(randn(T, nLatent)) for i in 1:nf_per_task[j]] for
143-
j in 1:nTask
144-
]
145-
146-
likelihoods .=
147-
init_likelihood.(likelihoods, inference, nf_per_task, _nMinibatch, nFeatures)
148-
xview = view_x(data, collect(range(1, _nMinibatch; step=1)))
149-
yview = view_y(likelihood, data, 1:nSamples(data))
150-
inference = tuple_inference(
151-
inference, nLatent, nFeatures, nSamples(data), _nMinibatch, xview, yview
152-
)
153104

154-
return MOSVGP{T,eltype(likelihoods),typeof(inference),nTask,nLatent}(
155-
X,
156-
corrected_y,
157-
nSamples,
158-
nDim,
159-
nFeatures,
160-
nLatent,
161-
nX,
162-
nTask,
105+
return MOSVGP{T,typeof(likelihoods),typeof(inference),n_task,num_latent}(
163106
nf_per_task,
164107
latent_f,
165108
likelihoods,
@@ -181,6 +124,6 @@ end
181124

182125
@traitimpl IsMultiOutput{MOSVGP}
183126

184-
nOutput(::MOSVGP{<:Real,<:AbstractLikelihood,<:AbstractInference,N,Q}) where {N,Q} = Q
127+
n_output(::MOSVGP{T,L,I,N,Q}) where {T,L,I,N,Q} = Q
185128
Zviews(m::MOSVGP) = Zview.(m.f)
186129
objective(m::MOSVGP, state, y) = ELBO(m, state, y)
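
Putting the new constructor signature together: the model is now built from kernels, likelihoods, an inference object, and the inducing-point locations only; the training data is no longer a constructor argument. A hedged construction sketch follows; the use of `ColVecs` for each element of `Zs` and the toy sizes are my assumptions, not something this diff pins down:

using AugmentedGaussianProcesses, KernelFunctions

kernel = SqExponentialKernel()
likelihoods = (GaussianLikelihood(), GaussianLikelihood())  # one likelihood per task
inference = AnalyticVI()
Zs = [ColVecs(rand(2, 10)) for _ in 1:2]  # one set of 10 inducing points (2-D inputs) per latent GP

model = MOSVGP(kernel, likelihoods, inference, Zs; verbose=1)

The data is then presumably supplied at training time (e.g. to `train!`) rather than to the constructor, but that call is not part of this diff.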
