|
| 1 | +using AugmentedGaussianProcesses; const AGP = AugmentedGaussianProcesses |
| 2 | +using LinearAlgebra, Distributions, Plots |
| 3 | +using BenchmarkTools |
| 4 | +b = 2.0 |
| 5 | +C()=1/(2b) |
| 6 | +g(y) = 0.0 |
| 7 | +α(y) = y^2 |
| 8 | +β(y) = 2*y |
| 9 | +γ(y) = 1.0 |
| 10 | +φ(r) = exp(-sqrt(r)/b) |
| 11 | +∇φ(r) = -exp(-sqrt(r)/b)/(2*b*sqrt(r)) |
| 12 | +ll(y,x) = 0.5*exp(0.5*y*x)*sech(0.5*sqrt(x^2)) |
| 13 | + |
| 14 | +## |
| 15 | +formula = :(p(y|x)=exp(0.5*y*x)*sech(0.5*sqrt(y^2 - 2*y*x + x^2))) |
| 16 | +# formula = :(p(y,x)=exp(0.5*y*x)*sech(0.5*sqrt(0.0 - 0.0*x + x^2))) |
| 17 | +formula.args[2].args[2].args |
| 18 | + |
| 19 | +topargs = formula.args[2].args[2].args |
| 20 | +if topargs[1] == :* |
| 21 | + @show topargs[1] |
| 22 | + global CC = copy(topargs[2]) |
| 23 | + popfirst!(topargs) |
| 24 | + popfirst!(topargs) |
| 25 | +else |
| 26 | + global CC = :0 |
| 27 | +end |
| 28 | +args2 = topargs[1] |
| 29 | +if args2.args[1] == :exp |
| 30 | + gargs = args2.args[2] |
| 31 | + if gargs.args[1] == :* |
| 32 | + deleteat!(gargs.args,findfirst(isequal(:x),gargs.args)) |
| 33 | + else |
| 34 | + @error "BAD BAD BAD" |
| 35 | + end |
| 36 | + global GG = copy(gargs) |
| 37 | + popfirst!(topargs) |
| 38 | +else |
| 39 | + global GG = :0 |
| 40 | +end |
| 41 | +args3 = topargs[1] |
| 42 | +seq = string(args3) |
| 43 | +findh= r"\([^(]*\-.*x.*\+.*x \^ 2[^)]*" |
| 44 | +b = occursin(findh,seq) |
| 45 | +m = match(findh,seq).match |
| 46 | +alphar = r"[^(][^-]*" |
| 47 | +malpha = match(alphar,m).match[1:end-1] |
| 48 | +betar = r"- [^x]*x" |
| 49 | +mbeta = match(betar,m).match[3:end] |
| 50 | +gammar = r"\+ [^x]*x \^ 2" |
| 51 | +mgamma = match(gammar,m).match[3:end] |
| 52 | + |
| 53 | +AA = :($malpha) |
| 54 | +BB = :($(mbeta[1:end-1])) |
| 55 | +GG = :($(mgamma == "x ^ 2" ? "1.0" : mgamma[1:end-5])) |
| 56 | + |
| 57 | +loc = findfirst(findh,seq) |
| 58 | +newseq = seq[1:loc[1]-1]*"r"*seq[(loc[end]+1):end] |
| 59 | +PHI = :($newseq) |
| 60 | +## |
| 61 | +f_lap = :(p(y|x)=0.5/β * exp(- sqrt((y^2 - 2*y*f + f^2))/β)) |
| 62 | +display.(AGP.@augmodel NewSuperLaplace Regression (p(y|x)=0.5/β * exp( y * x) * exp(- sqrt((sqrt(y^2) - exp(4.0*y)*x + 2.0*x^2))/β)) β) |
| 63 | +pdfstring = "(0.5 / β) * exp(2*y*x) * exp(-(sqrt((y ^ 2 - 2.0 * y * x) + 1.0*x ^ 2)) / β)" |
| 64 | + |
| 65 | +Gstring |
| 66 | +## |
| 67 | +PhiHstring = match(Regex("(?<=$(AGP.correct_parenthesis(Gstringfull))x\\) \\* ).*"),pdfstring).match |
| 68 | +Hstring = match(r"(?<=\().+\-.*x.*\+.+x \^ 2(?=\))",PhiHstring).match |
| 69 | +locx = findfirst("x ^ 2",PhiHstring) |
| 70 | +count_parenthesis = 1 |
| 71 | +locf = locx[1] |
| 72 | +while count_parenthesis != 0 |
| 73 | + global locf = locf - 1 |
| 74 | + println(PhiHstring[locf]) |
| 75 | + if PhiHstring[locf] == ')' |
| 76 | + global count_parenthesis += 1 |
| 77 | + elseif PhiHstring[locf] == '(' |
| 78 | + global count_parenthesis -= 1 |
| 79 | + end |
| 80 | +end |
| 81 | +Hstring = PhiHstring[(locf+1):locx[end]] |
| 82 | + |
| 83 | +alphar = r"[^(][^-]*" |
| 84 | +alpha_string = match(alphar,Hstring).match[1:end-1] |
| 85 | +# betar = r"(?>=- )[^x]+(?= * x)" |
| 86 | +betar = r"(?<=- )[^x]+(?= * x)" |
| 87 | +mbeta = match(betar,Hstring).match |
| 88 | +while last(mbeta) == ' ' || last(mbeta) == '*' |
| 89 | + global mbeta = mbeta[1:end-1] |
| 90 | +end |
| 91 | +mbeta |
| 92 | +gammar = r"(?<=\+ )[^x]*(?=x \^ 2)" |
| 93 | +mgamma = match(gammar,m).match == "" ? "1.0" : match(gammar,m).match |
| 94 | +## |
| 95 | +findnext(isequal(')'),PhiHstring,locx[end]) |
| 96 | +code = Meta.parse(PhiHstring) |
| 97 | +code.args |
| 98 | + |
| 99 | +S = code.args[2].args[2].args[2].args[2].args[2].args |
| 100 | +S = code.args[2].args[2].args[2].args[2].args |
| 101 | +for args in S.args |
| 102 | + if args == :(x ^ 2) |
| 103 | + @show "BLAH" |
| 104 | + end |
| 105 | +end |
| 106 | +S = string(code.args[2].args[2]) |
| 107 | +Hstring = match(r"(?<=\().*x \^ 2.*\-.*x.*\+.*(?=\))",S) |
| 108 | + |
| 109 | +## |
| 110 | + |
| 111 | +txt = AGP.@augmodel("NewLaplace","Regression",C,g,α,β,γ,φ,∇φ) |
| 112 | + |
| 113 | +# NewLaplaceLikelihood() |> display |
| 114 | +N = 500 |
| 115 | +σ = 1.0 |
| 116 | +X = sort(rand(N,1),dims=1) |
| 117 | +K = kernelmatrix(X,RBFKernel(0.1))+1e-4*I |
| 118 | +L = Matrix(cholesky(K).L) |
| 119 | +y_true = rand(MvNormal(K)) |
| 120 | +y = y_true+randn(length(y_true))*2 |
| 121 | +p = scatter(X[:],y,lab="data") |
| 122 | +NewLaplaceLikelihood() |> display |
| 123 | +m = VGP(X,y,RBFKernel(0.5),NewLaplaceLikelihood(),AnalyticVI(),optimizer=false) |
| 124 | +train!(m,iterations=100) |
| 125 | +y_p, sig_p = proba_y(m,collect(0:0.01:1)) |
| 126 | + |
| 127 | +m2 = VGP(X,y,RBFKernel(0.5),LaplaceLikelihood(b),AnalyticVI(),optimizer=false) |
| 128 | +train!(m2,iterations=100) |
| 129 | +y_p2, sig_p2 = proba_y(m2,collect(0:0.01:1)) |
| 130 | + |
| 131 | +plot!(X,y_true,lab="truth") |
| 132 | + |
| 133 | +plot!(collect(0:0.01:1),y_p,lab="Auto Laplace") |
| 134 | +plot!(collect(0:0.01:1),y_p2,lab="Classic Laplace") |> display |
| 135 | + |
| 136 | + |
| 137 | +# @btime train!($m,iterations=1) |
| 138 | +# @btime train!($m2,iterations=1) |
| 139 | + |
| 140 | +### |
0 commit comments