@@ -40,8 +40,7 @@ mutable struct SVGP{L<:Likelihood,I<:Inference,T<:Real,V<:AbstractVector{T}} <:
4040 Σ:: LatentArray{Symmetric{T,Matrix{T}}}
4141 η₁:: LatentArray{V}
4242 η₂:: LatentArray{Symmetric{T,Matrix{T}}}
43- μ₀:: LatentArray{V}
44- opt_μ₀:: LatentArray{Optimizer}
43+ μ₀:: LatentArray{MeanPrior{T}}
4544 Kmm:: LatentArray{Symmetric{T,Matrix{T}}}
4645 invKmm:: LatentArray{Symmetric{T,Matrix{T}}}
4746 Knm:: LatentArray{Matrix{T}}
6059function SVGP (X:: AbstractArray{T1} ,y:: AbstractArray{T2} ,kernel:: Union{Kernel,AbstractVector{<:Kernel}} ,
6160 likelihood:: LikelihoodType ,inference:: InferenceType ,
6261 nInducingPoints:: Integer
63- ;verbose:: Integer = 0 ,Autotuning:: Bool = true ,atfrequency:: Integer = 1 ,μ₀:: AbstractVector{T1} = Vector {T1} (),
62+ ;verbose:: Integer = 0 ,Autotuning:: Bool = true ,atfrequency:: Integer = 1 ,
63+ mean:: Union{<:Real,AbstractVector{<:Real},MeanPrior} = ZeroMean (),
6464 IndependentPriors:: Bool = true , OptimizeInducingPoints:: Bool = false ,ArrayType:: UnionAll = Vector) where {T1<: Real ,T2,LikelihoodType<: Likelihood ,InferenceType<: Inference }
6565
6666 X,y,nLatent,likelihood = check_data! (X,y,likelihood)
@@ -83,12 +83,14 @@ function SVGP(X::AbstractArray{T1},y::AbstractArray{T2},kernel::Union{Kernel,Abs
8383 Knm = deepcopy (κ)
8484 K̃ = LatentArray ([zeros (T1,inference. Stochastic ? inference. nSamplesUsed : nSample) for _ in 1 : nPrior])
8585 Kmm = LatentArray ([similar (Σ[1 ]) for _ in 1 : nPrior]); invKmm = similar .(Kmm)
86- if ! isempty (μ₀) && length (μ₀) == nFeature
87- μ₀ = [μ₀ for _ in 1 : nPrior]
86+ μ₀ = []
87+ if typeof (mean) <: Real
88+ μ₀ = [ConstantMean (mean) for _ in 1 : nPrior]
89+ elseif typeof (mean) <: AbstractVector{<:Real}
90+ μ₀ = [EmpiricalMean (mean) for _ in 1 : nPrior]
8891 else
89- μ₀ = [zeros (T1,nFeature) for _ in 1 : nPrior]
92+ μ₀ = [mean for _ in 1 : nPrior]
9093 end
91- opt_μ₀ = [Adam (α= 1.0 ) for _ in 1 : nPrior]
9294
9395 nSamplesUsed = nSample
9496 if inference. Stochastic
@@ -97,7 +99,6 @@ function SVGP(X::AbstractArray{T1},y::AbstractArray{T2},kernel::Union{Kernel,Abs
9799 opt = kernel[1 ]. fields. variance. opt
98100 opt. α = opt. α* 0.1
99101 setoptimizer! .(kernel,[copy (opt) for _ in 1 : nLatent])
100- broadcast (opt-> opt. α= opt. α* 0.1 ,opt_μ₀)
101102 end
102103
103104 likelihood = init_likelihood (likelihood,inference,nLatent,nSamplesUsed)
@@ -106,7 +107,7 @@ function SVGP(X::AbstractArray{T1},y::AbstractArray{T2},kernel::Union{Kernel,Abs
106107 nSample, nDim, nFeature, nLatent,
107108 IndependentPriors,nPrior,
108109 Z,μ,Σ,η₁,η₂,
109- μ₀,opt_μ₀, Kmm,invKmm,Knm,κ,K̃,
110+ μ₀,Kmm,invKmm,Knm,κ,K̃,
110111 kernel,likelihood,inference,
111112 verbose,Autotuning,atfrequency,OptimizeInducingPoints,false )
112113 if isa (inference. optimizer_η₁[1 ],ALRSVI)
0 commit comments