|
51 | 51 |
|
52 | 52 |
|
53 | 53 | function local_updates!(model::VGP{<:AugmentedLogisticLikelihood,<:AnalyticInference}) |
54 | | - model.likelihood.c .= broadcast((μ,Σ)->sqrt.(diag(Σ)+μ.^2),model.μ,model.Σ) |
| 54 | + model.likelihood.c .= broadcast((μ,Σ)->sqrt.(Σ+abs2.(μ)),model.μ,diag.(model.Σ)) |
55 | 55 | model.likelihood.θ .= broadcast(c->0.5*tanh.(0.5*c)./c,model.likelihood.c) |
56 | 56 | end |
57 | 57 |
|
58 | 58 | function local_updates!(model::SVGP{<:AugmentedLogisticLikelihood,<:AnalyticInference}) |
59 | | - model.likelihood.c .= broadcast((μ,Σ,K̃,κ)->sqrt.(K̃+opt_diag(κ*Σ,κ)+(κ*μ).^2),model.μ,model.Σ,model.K̃,model.κ) |
| 59 | + model.likelihood.c .= broadcast((μ,Σ,K̃,κ)->sqrt.(K̃+opt_diag(κ*Σ,κ)+abs2.(κ*μ)),model.μ,model.Σ,model.K̃,model.κ) |
60 | 60 | model.likelihood.θ .= broadcast(c->0.5*tanh.(0.5*c)./c,model.likelihood.c) |
61 | 61 | end |
62 | 62 |
|
|
92 | 92 |
|
93 | 93 | function expecLogLikelihood(model::VGP{AugmentedLogisticLikelihood{T}}) where T |
94 | 94 | tot = -model.nLatent*(0.5*model.nSample*logtwo) |
95 | | - tot += sum(broadcast((μ,y,θ,Σ)->0.5.*(sum(μ.*y)-dot(θ,Σ+μ.^2)), |
| 95 | + tot += sum(broadcast((μ,y,θ,Σ)->0.5.*(sum(μ.*y)-dot(θ,Σ+abs2.(μ))), |
96 | 96 | model.μ,model.y,model.likelihood.θ,diag.(model.Σ))) |
97 | 97 | return tot |
98 | 98 | end |
99 | 99 |
|
100 | 100 | function expecLogLikelihood(model::SVGP{AugmentedLogisticLikelihood{T}}) where T |
101 | | - tot = -model.nLatent*(0.5*model.nSample*logtwo) |
102 | | - tot += sum(broadcast((κμ,y,θ,κΣκ,K̃)->0.5.*(sum(κμ.*y)-dot(θ,K̃+κΣκ+κμ.^2))), |
| 101 | + tot = -model.nLatent*(0.5*model.inference.nSamplesUsed*logtwo) |
| 102 | + tot += sum(broadcast((κμ,y,θ,κΣκ,K̃)->0.5.*(sum(κμ.*y)-dot(θ,K̃+κΣκ+abs2.(κμ)))), |
103 | 103 | model.κ.*model.μ,model.y,model.likelihood.θ,opt_diag(model.κ.*model.Σ,model.κ'),model.K̃) |
104 | 104 | return model.inference.ρ*tot |
105 | 105 | end |
|
0 commit comments