Skip to content

Commit 29d9ef2

Browse files
committed
Added custom priors (constant and zero)
1 parent e716c85 commit 29d9ef2

File tree

9 files changed

+104
-39
lines changed

9 files changed

+104
-39
lines changed

src/AugmentedGaussianProcesses.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ include("kernels/KernelModule.jl")
2424
include("kmeans/KMeansModule.jl")
2525
include("functions/PGSampler.jl")
2626
#include("functions/PerturbativeCorrection.jl")
27-
include("functions/GPAnalysisTools.jl")
27+
# include("functions/GPAnalysisTools.jl")
2828
# include("functions/IO_model.jl")
2929
#Custom modules
3030
using .KernelModule
@@ -66,6 +66,7 @@ abstract type Inference{T<:Real} end
6666
abstract type Likelihood{T<:Real} end
6767

6868
const LatentArray = Vector #For future optimization : How collection of latent GPs are stored
69+
include("prior/meanprior.jl")
6970

7071
include("models/AbstractGP.jl")
7172
include("models/GP.jl")

src/autotuning.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function update_hyperparameters!(model::Union{VGP,GP})
88

99
apply_gradients_lengthscale!.(model.kernel,grads_l) #Send the derivative of the matrix to the specific gradient of the model
1010
apply_gradients_variance!.(model.kernel,grads_v) #Send the derivative of the matrix to the specific gradient of the model
11-
apply_gradients_mean_prior!.(model.μ₀,grads_μ₀,model.opt_μ₀)
11+
update!.(model.μ₀,grads_μ₀)
1212

1313
model.inference.HyperParametersUpdated = true
1414
end
@@ -33,7 +33,7 @@ function update_hyperparameters!(model::SVGP{<:Likelihood,<:Inference,T}) where
3333
grads_μ₀ = map(f_μ₀,1:model.nPrior)
3434
apply_gradients_lengthscale!.(model.kernel,grads_l)
3535
apply_gradients_variance!.(model.kernel,grads_v)
36-
apply_gradients_mean_prior!.(model.μ₀,grads_μ₀,model.opt_μ₀)
36+
update!.(model.μ₀,grads_μ₀)
3737
model.inference.HyperParametersUpdated = true
3838
end
3939

src/models/GP.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ mutable struct GP{L<:Likelihood,I<:Inference,T<:Real,V<:AbstractVector{T}} <: Ab
3434
nLatent::Int64 # Number of latent GPs
3535
IndependentPriors::Bool # Use of separate priors for each latent GP
3636
nPrior::Int64 # Equal to 1 or nLatent given IndependentPriors
37-
μ₀::LatentArray{V}
38-
opt_μ₀::LatentArray{Optimizer}
37+
μ₀::LatentArray{MeanPrior{T}}
3938
Knn::LatentArray{Symmetric{T,Matrix{T}}}
4039
invKnn::LatentArray{Symmetric{T,Matrix{T}}}
4140
kernel::LatentArray{Kernel{T}}
@@ -49,7 +48,7 @@ end
4948

5049

5150
function GP(X::AbstractArray{T1,N1},y::AbstractArray{T2,N2},kernel::Union{Kernel,AbstractVector{<:Kernel}}; noise::Real=1e-5,
52-
verbose::Integer=0,Autotuning::Bool=true,atfrequency::Integer=1,μ₀::AbstractVector{T1}=Vector{T1}(),
51+
verbose::Integer=0,Autotuning::Bool=true,atfrequency::Integer=1,mean::Union{<:Real,AbstractVector{<:Real},MeanPrior}=ZeroMean(),
5352
IndependentPriors::Bool=true,ArrayType::UnionAll=Vector) where {T1<:Real,T2,N1,N2}
5453
likelihood = GaussianLikelihood(noise)
5554
inference = Analytic()
@@ -61,19 +60,21 @@ function GP(X::AbstractArray{T1,N1},y::AbstractArray{T2,N2},kernel::Union{Kernel
6160

6261
Knn = LatentArray([Symmetric(Matrix{T1}(I,nFeature,nFeature)) for _ in 1:nPrior]);
6362
invKnn = copy(Knn)
64-
if !isempty(μ₀) && length(μ₀) == nFeature
65-
μ₀ = [μ₀ for _ in 1:nPrior]
63+
μ₀ = []
64+
if typeof(mean) <: Real
65+
μ₀ = [ConstantMean(mean) for _ in 1:nPrior]
66+
elseif typeof(mean) <: AbstractVector{<:Real}
67+
μ₀ = [EmpiricalMean(mean) for _ in 1:nPrior]
6668
else
67-
μ₀ = [zeros(T1,nFeature) for _ in 1:nPrior]
69+
μ₀ = [mean for _ in 1:nPrior]
6870
end
69-
opt_μ₀ = [Adam(α=0.1) for _ in 1:nPrior]
7071
likelihood = init_likelihood(likelihood,inference,nLatent,nSample)
7172
inference = init_inference(inference,nLatent,nSample,nSample,nSample)
7273

7374
model = GP{GaussianLikelihood{T1},Analytic{T1},T1,ArrayType{T1}}(X,y,
7475
nFeature, nDim, nFeature, nLatent,
7576
IndependentPriors,nPrior,
76-
μ₀,opt_μ₀,Knn,invKnn,kernel,likelihood,inference,
77+
μ₀,Knn,invKnn,kernel,likelihood,inference,
7778
verbose,Autotuning,atfrequency,false)
7879
computeMatrices!(model)
7980
model.Trained = true

src/models/SVGP.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ mutable struct SVGP{L<:Likelihood,I<:Inference,T<:Real,V<:AbstractVector{T}} <:
4040
Σ::LatentArray{Symmetric{T,Matrix{T}}}
4141
η₁::LatentArray{V}
4242
η₂::LatentArray{Symmetric{T,Matrix{T}}}
43-
μ₀::LatentArray{V}
44-
opt_μ₀::LatentArray{Optimizer}
43+
μ₀::LatentArray{MeanPrior{T}}
4544
Kmm::LatentArray{Symmetric{T,Matrix{T}}}
4645
invKmm::LatentArray{Symmetric{T,Matrix{T}}}
4746
Knm::LatentArray{Matrix{T}}
@@ -60,7 +59,8 @@ end
6059
function SVGP(X::AbstractArray{T1},y::AbstractArray{T2},kernel::Union{Kernel,AbstractVector{<:Kernel}},
6160
likelihood::LikelihoodType,inference::InferenceType,
6261
nInducingPoints::Integer
63-
;verbose::Integer=0,Autotuning::Bool=true,atfrequency::Integer=1,μ₀::AbstractVector{T1}=Vector{T1}(),
62+
;verbose::Integer=0,Autotuning::Bool=true,atfrequency::Integer=1,
63+
mean::Union{<:Real,AbstractVector{<:Real},MeanPrior}=ZeroMean(),
6464
IndependentPriors::Bool=true, OptimizeInducingPoints::Bool=false,ArrayType::UnionAll=Vector) where {T1<:Real,T2,LikelihoodType<:Likelihood,InferenceType<:Inference}
6565

6666
X,y,nLatent,likelihood = check_data!(X,y,likelihood)
@@ -83,12 +83,14 @@ function SVGP(X::AbstractArray{T1},y::AbstractArray{T2},kernel::Union{Kernel,Abs
8383
Knm = deepcopy(κ)
8484
K̃ = LatentArray([zeros(T1,inference.Stochastic ? inference.nSamplesUsed : nSample) for _ in 1:nPrior])
8585
Kmm = LatentArray([similar(Σ[1]) for _ in 1:nPrior]); invKmm = similar.(Kmm)
86-
if !isempty(μ₀) && length(μ₀) == nFeature
87-
μ₀ = [μ₀ for _ in 1:nPrior]
86+
μ₀ = []
87+
if typeof(mean) <: Real
88+
μ₀ = [ConstantMean(mean) for _ in 1:nPrior]
89+
elseif typeof(mean) <: AbstractVector{<:Real}
90+
μ₀ = [EmpiricalMean(mean) for _ in 1:nPrior]
8891
else
89-
μ₀ = [zeros(T1,nFeature) for _ in 1:nPrior]
92+
μ₀ = [mean for _ in 1:nPrior]
9093
end
91-
opt_μ₀ = [Adam(α=1.0) for _ in 1:nPrior]
9294

9395
nSamplesUsed = nSample
9496
if inference.Stochastic
@@ -97,7 +99,6 @@ function SVGP(X::AbstractArray{T1},y::AbstractArray{T2},kernel::Union{Kernel,Abs
9799
opt = kernel[1].fields.variance.opt
98100
opt.α = opt.α*0.1
99101
setoptimizer!.(kernel,[copy(opt) for _ in 1:nLatent])
100-
broadcast(opt->opt.α=opt.α*0.1,opt_μ₀)
101102
end
102103

103104
likelihood = init_likelihood(likelihood,inference,nLatent,nSamplesUsed)
@@ -106,7 +107,7 @@ function SVGP(X::AbstractArray{T1},y::AbstractArray{T2},kernel::Union{Kernel,Abs
106107
nSample, nDim, nFeature, nLatent,
107108
IndependentPriors,nPrior,
108109
Z,μ,Σ,η₁,η₂,
109-
μ₀,opt_μ₀,Kmm,invKmm,Knm,κ,K̃,
110+
μ₀,Kmm,invKmm,Knm,κ,K̃,
110111
kernel,likelihood,inference,
111112
verbose,Autotuning,atfrequency,OptimizeInducingPoints,false)
112113
if isa(inference.optimizer_η₁[1],ALRSVI)

src/models/VGP.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ end
5656

5757
function VGP(X::AbstractArray{T1,N1},y::AbstractArray{T2,N2},kernel::Union{Kernel,AbstractVector{<:Kernel}},
5858
likelihood::LikelihoodType,inference::InferenceType;
59-
verbose::Integer=0,Autotuning::Bool=true,atfrequency::Integer=1,mean::Union{T,MeanPrior}=ConstantMean(0.0),
59+
verbose::Integer=0,Autotuning::Bool=true,atfrequency::Integer=1,mean::Union{<:Real,AbstractVector{<:Real},MeanPrior}=ZeroMean(),
6060
IndependentPriors::Bool=true,ArrayType::UnionAll=Vector) where {T1<:Real,T2,N1,N2,LikelihoodType<:Likelihood,InferenceType<:Inference}
6161

6262
X,y,nLatent,likelihood = check_data!(X,y,likelihood)
@@ -69,12 +69,14 @@ function VGP(X::AbstractArray{T1,N1},y::AbstractArray{T2,N2},kernel::Union{Kerne
6969
μ = LatentArray([zeros(T1,nFeature) for _ in 1:nLatent]); η₁ = deepcopy(μ)
7070
Σ = LatentArray([Symmetric(Matrix(Diagonal(one(T1)*I,nFeature))) for _ in 1:nLatent]);
7171
η₂ = -0.5*inv.(Σ);
72+
μ₀ = []
7273
if typeof(mean) <: Real
73-
mean = [ConstantMean(mean) for _ in 1:nPrior]
74+
μ₀ = [ConstantMean(mean) for _ in 1:nPrior]
75+
elseif typeof(mean) <: AbstractVector{<:Real}
76+
μ₀ = [EmpiricalMean(mean) for _ in 1:nPrior]
7477
else
75-
mean = [mean for _ in 1:nPrior]
78+
μ₀ = [mean for _ in 1:nPrior]
7679
end
77-
7880
Knn = LatentArray([deepcopy(Σ[1]) for _ in 1:nPrior]);
7981
invKnn = copy(Knn)
8082

@@ -84,7 +86,7 @@ function VGP(X::AbstractArray{T1,N1},y::AbstractArray{T2,N2},kernel::Union{Kerne
8486
VGP{LikelihoodType,InferenceType,T1,ArrayType{T1}}(X,y,
8587
nFeature, nDim, nFeature, nLatent,
8688
IndependentPriors,nPrior,μ,Σ,η₁,η₂,
87-
mean,Knn,invKnn,kernel,likelihood,inference,
89+
μ₀,Knn,invKnn,kernel,likelihood,inference,
8890
verbose,Autotuning,atfrequency,false)
8991
end
9092

src/prior/constantmean.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ function ConstantMean(c::T=1.0;opt::Optimizer=Adam(α=0.01)) where {T<:Real}
1313
end
1414

1515
function update!(μ::ConstantMean{T},grad::AbstractVector{T}) where {T<:Real}
16-
μ.C .+= update!(μ.opt,sum(grad))
16+
μ.C += update!(μ.opt,sum(grad))
1717
end
1818

19-
Base.+(x::Real,y::ConstantMean{<:Real}) = x+y.C
20-
Base.+(x::AbstractVector{<:Real},y::ConstantMean{<:Real}) = x.+y.C
21-
Base.+(x::ConstantMean{<:Real},y::AbstractVector{<:Real}) = y.+x.C
22-
Base.+(x::ConstantMean{<:Real},y::Real) = y+x.C
23-
Base.+(x::ConstantMean{<:Real},y::AbstractVector{<:Real}) = ConstantMean(x.C+y.C)
24-
Base.-(x::Real,y::ConstantMean) = x - y.C
25-
Base.-(x::AbstractVector{<:Real},y::ConstantMean) = x .- y.C
26-
Base.-(x::ConstantMean{<:Real},y::Real) = x.C - y
27-
Base.-(x::ConstantMean{<:Real},y::AbstractVector{<:Real}) = x.C .- y
28-
Base.-(x::ConstantMean{<:Real},y::AbstractVector{<:Real}) = ConstantMean(x.C-y.C)
29-
Base.*(A::AbstractMatrix{<:Real},y::ConstantMean{T}) where {T<:Real} = y.C*A*ones(T,size(A,2),1)
30-
Base.*(y::ConstantMean{T},A::AbstractMatrix{<:Real}) where {T<:Real} = y.C*ones(T,1,size(A,1))*A
31-
Base.convert(::T1,x::ConstantMean{T2}) where {T1<:Real,T2<:Real} = T1(x.C)
19+
Base.:+(x::Real,y::ConstantMean{<:Real}) = x+y.C
20+
Base.:+(x::AbstractVector{<:Real},y::ConstantMean{<:Real}) = x.+y.C
21+
Base.:+(x::ConstantMean{<:Real},y::Real) = y+x.C
22+
Base.:+(x::ConstantMean{<:Real},y::AbstractVector{<:Real}) = y.+x.C
23+
Base.:+(x::ConstantMean{<:Real},y::ConstantMean{<:Real}) = ConstantMean(x.C+y.C)
24+
Base.:-(x::Real,y::ConstantMean) = x - y.C
25+
Base.:-(x::AbstractVector{<:Real},y::ConstantMean) = x .- y.C
26+
Base.:-(x::ConstantMean{<:Real},y::Real) = x.C - y
27+
Base.:-(x::ConstantMean{<:Real},y::AbstractVector{<:Real}) = x.C .- y
28+
Base.:-(x::ConstantMean{<:Real},y::AbstractVector{<:Real}) = ConstantMean(x.C-y.C)
29+
Base.:*(A::AbstractMatrix{<:Real},y::ConstantMean{T}) where {T<:Real} = y.C*A*ones(T,size(A,2))
30+
Base.:*(y::ConstantMean{T},A::AbstractMatrix{<:Real}) where {T<:Real} = y.C*ones(T,1,size(A,1))*A
31+
Base.:convert(::T1,x::ConstantMean{T2}) where {T1<:Real,T2<:Real} = T1(x.C)

src/prior/empiricalmean.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
mutable struct EmpiricalMean{T<:Real,V<:AbstractVector{<:Real}} <: MeanPrior{T}
2+
C::V
3+
opt::Optimizer
4+
end
5+
6+
"""
7+
EmpiricalMean(c)
8+
Construct an empirical mean with values `c`
9+
Optionally give an optimizer `opt` (`Adam(α=0.01)` by default)
10+
"""
11+
function EmpiricalMean(c::V=1.0;opt::Optimizer=Adam(α=0.01)) where {V<:AbstractVector{<:Real}}
12+
EmpiricalMean{eltype(c),V}(c,opt)
13+
end
14+
15+
function update!(μ::EmpiricalMean{T},grad::AbstractVector{T}) where {T<:Real}
16+
μ.C .+= update!(μ.opt,grad)
17+
end
18+
19+
Base.:+(x::Real,y::EmpiricalMean{<:Real}) = x.+y.C
20+
Base.:+(x::AbstractVector{<:Real},y::EmpiricalMean{<:Real}) = x+y.C
21+
Base.:+(x::EmpiricalMean{<:Real},y::Real) = y.+x.C
22+
Base.:+(x::EmpiricalMean{<:Real},y::AbstractVector{<:Real}) = y+x.C
23+
Base.:+(x::EmpiricalMean{<:Real},y::EmpiricalMean{<:Real}) = EmpiricalMean(x.C+y.C)
24+
Base.:-(x::Real,y::EmpiricalMean) = x .- y.C
25+
Base.:-(x::AbstractVector{<:Real},y::EmpiricalMean) = x - y.C
26+
Base.:-(x::EmpiricalMean{<:Real},y::Real) = x.C .- y
27+
Base.:-(x::EmpiricalMean{<:Real},y::AbstractVector{<:Real}) = x.C - y
28+
Base.:-(x::EmpiricalMean{<:Real},y::EmpiricalMean{<:Real}) = EmpiricalMean(x.C-y.C)
29+
Base.:*(A::AbstractMatrix{<:Real},y::EmpiricalMean{T}) where {T<:Real} = A*y.C
30+
Base.:*(y::EmpiricalMean{T},A::AbstractMatrix{<:Real}) where {T<:Real} = transpose(y)*A
31+
Base.:convert(::T1,x::EmpiricalMean{T2}) where {T1<:Real,T2<:Real} = T1(x.C)

src/prior/meanprior.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ abstract type MeanPrior{T} end
33
import Base: +, -, *, convert
44

55
include("constantmean.jl")
6+
include("zeromean.jl")
7+
include("empiricalmean.jl")

src/prior/zeromean.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
mutable struct ZeroMean{T<:Real} <: MeanPrior{T}
2+
end
3+
4+
"""
5+
ZeroMean(c)
6+
Construct a constant mean with constant 0
7+
"""
8+
function ZeroMean()
9+
ZeroMean{Float64}()
10+
end
11+
12+
function update!(μ::ZeroMean{T},grad::AbstractVector{T}) where {T<:Real}
13+
end
14+
15+
Base.:+(x::Real,y::ZeroMean{<:Real}) = x
16+
Base.:+(x::AbstractVector{<:Real},y::ZeroMean{<:Real}) = x
17+
Base.:+(x::ZeroMean{<:Real},y::Real) = y
18+
Base.:+(x::ZeroMean{<:Real},y::AbstractVector{<:Real}) = y
19+
Base.:+(x::ZeroMean{<:Real},y::ConstantMean{<:Real}) = ConstantMean(y.C)
20+
Base.:-(x::Real,y::ZeroMean{<:Real}) = x
21+
Base.:-(x::AbstractVector{<:Real},y::ZeroMean) = x
22+
Base.:-(x::ZeroMean{<:Real},y::Real) = -y
23+
Base.:-(x::ZeroMean{<:Real},y::AbstractVector{<:Real}) = -y
24+
Base.:-(x::ZeroMean{<:Real},y::ConstantMean{<:Real}) = ConstantMean(-y.C)
25+
Base.:*(A::AbstractMatrix{<:Real},y::ZeroMean{T}) where {T<:Real} = zeros(T,size(A,2))
26+
Base.:*(y::ZeroMean{T},A::AbstractMatrix{<:Real}) where {T<:Real} = zeros(T,1,size(A,1))
27+
Base.:convert(::T1,x::ZeroMean{T2}) where {T1<:Real,T2<:Real} = T1(x.C)

0 commit comments

Comments
 (0)