Skip to content

Commit 34b026e

Browse files
authored
Fixes issues with Inducing Points gradients (#71)
* Renamed reverse to zygote and other name simplifications * Correction for Z gradients * Work on making the hyperparameter gradients correct * Large changes to allow updating everything at once * Remove PDMats for Cholesky * Fixed ELBO issues * General changes * All fixed but with ugly hacks * Added a version for non-sparse GP * Additional fixes * Fixes * More fixes * Fixes 3 * Fixes 4 * Added tests for utils and fixes * Fixed prior mean issues * WIP Z_opt * Fixes on nomenclature and optimization * More corrections for logistic softmax * Fixed most of the bugs * Fix Zygote issue * Solve issue with kernelmatrix_diag and besselk * Updated version of KernelFunctions * Worked on the distribution module * Fixing rules and removing inducing points module * Removed InducingPoints * Removing last traces of PolyaGammaDist * Last fixes * Moved files appropriately * Correct typing logistic * Reintroduce twoπ * Update compat to 1.5 and remove Manifest * Fix bug negative binomial * Update ci.yml * Update Project.toml * Change Inference to AbstractInference and Likelihood to AbstractLikelihood * Cleaning datacontainer * Large docs updates * Fixing sampling and data container
1 parent 2361c93 commit 34b026e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+1730
-2366
lines changed

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
style = "blue"

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
matrix:
1515
version:
1616
- '1'
17-
- '1.3'
17+
- '1.6'
1818
- 'nightly'
1919
os:
2020
- ubuntu-latest

Project.toml

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
11
name = "AugmentedGaussianProcesses"
22
uuid = "38eea1fd-7d7d-5162-9d08-f89d0f2e271e"
33
authors = ["Theo Galy-Fajou <theo.galyfajou@gmail.com>"]
4-
version = "0.9.4"
4+
version = "0.10.0"
55

66
[deps]
77
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
8-
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
9-
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
109
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
11-
DeterminantalPointProcesses = "4d968f93-c0cd-4b7f-b189-b034d1a24a0e"
1210
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
1311
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1412
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
1513
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1614
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
15+
InducingPoints = "b4bd816d-b975-4295-ac05-5f2992945579"
1716
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
1817
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1918
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
20-
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
21-
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
2219
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
2320
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2421
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
@@ -32,24 +29,20 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3229

3330
[compat]
3431
AdvancedHMC = "0.2.13"
35-
Clustering = "0.13.3, 0.14"
36-
DataStructures = "0.17, 0.18"
37-
DeterminantalPointProcesses = "0.1.0"
3832
Distances = "0.8, 0.9, 0.10"
3933
Distributions = "0.21.5, 0.22, 0.23, 0.24"
4034
FastGaussQuadrature = "0.4"
41-
Flux = "0.10, 0.11"
35+
Flux = "0.10, 0.11, 0.12"
4236
ForwardDiff = "0.10"
43-
KernelFunctions = "0.5, 0.6, 0.7, 0.8"
37+
InducingPoints = "0.1"
38+
KernelFunctions = "0.8, 0.9"
4439
MCMCChains = "0.3.15, 2.0, 3.0, 4.0"
45-
MLDataUtils = "0.5"
46-
PDMats = "0.10, 0.11"
4740
ProgressMeter = "1"
4841
RecipesBase = "1.0, 1.1"
4942
Reexport = "0.2, 1"
5043
SimpleTraits = "0.9"
5144
SpecialFunctions = "0.9, 0.10, 1"
5245
StatsBase = "0.32, 0.33"
5346
StatsFuns = "0.8, 0.9"
54-
Zygote = "0.5, 0.6"
55-
julia = "1.3"
47+
Zygote = "0.6.7"
48+
julia = "1.6"

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,6 @@ makedocs(modules = [AugmentedGaussianProcesses],
7979
deploydocs(
8080
deps = Deps.pip("mkdocs", "python-markdown-math"),
8181
repo = "github.com/theogf/AugmentedGaussianProcesses.jl.git",
82-
target = "build"
82+
target = "build",
83+
push_preview = true,
8384
)

docs/src/template_likelihood.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ See all functions you need to implement
1111
1212
1313
"""
14-
struct TemplateLikelihood{T<:Real,A<:AbstractVector{T}} <: Likelihood{T}
14+
struct TemplateLikelihood{T<:Real,A<:AbstractVector{T}} <: AbstractLikelihood{T}
1515
## Additional parameters can be added
1616
θ::A
1717
function TemplateLikelihood{T}() where {T<:Real}

src/AugmentedGaussianProcesses.jl

Lines changed: 80 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -4,102 +4,94 @@ General Framework for the data augmented Gaussian Processes
44
55
"""
66
module AugmentedGaussianProcesses
7-
const AGP = AugmentedGaussianProcesses; export AGP
8-
export AbstractGP, GP, VGP, SVGP, VStP, MCGP, MOVGP, MOSVGP, MOARGP,OnlineSVGP
9-
export Likelihood, RegressionLikelihood, ClassificationLikelihood, MultiClassLikelihood
10-
export GaussianLikelihood, StudentTLikelihood, LaplaceLikelihood, HeteroscedasticLikelihood
11-
export LogisticLikelihood, BayesianSVM
12-
export SoftMaxLikelihood, LogisticSoftMaxLikelihood
13-
export PoissonLikelihood, NegBinomialLikelihood
14-
export Inference, Analytic, AnalyticVI, AnalyticSVI, GibbsSampling, HMCSampling, MCIntegrationVI, MCIntegrationSVI, QuadratureVI, QuadratureSVI
15-
export NumericalVI, NumericalSVI
16-
export PriorMean, ZeroMean, ConstantMean, EmpiricalMean, AffineMean
17-
#Useful functions
18-
export train!, sample
19-
export predict_f, predict_y, proba_y
20-
export fstar, ELBO
21-
export covariance, diag_covariance, prior_mean
22-
export @augmodel
237

24-
#General modules
25-
using Reexport
26-
using LinearAlgebra
27-
using Random
28-
@reexport using KernelFunctions
29-
using KernelFunctions: ColVecs, RowVecs
30-
using Zygote, ForwardDiff
31-
using Flux: params, destructure
32-
@reexport using Flux.Optimise
33-
using PDMats: PDMat, invquad
34-
using AdvancedHMC
35-
using MCMCChains
36-
using StatsBase
37-
using StatsFuns
38-
using SpecialFunctions
39-
using Distributions
40-
using FastGaussQuadrature: gausshermite
41-
using ProgressMeter, SimpleTraits
42-
#Exported modules
43-
# export KMeansModule
44-
export KMeansInducingPoints
8+
const AGP = AugmentedGaussianProcesses
9+
export AGP
10+
export AbstractGP, GP, VGP, SVGP, VStP, MCGP, MOVGP, MOSVGP, MOARGP, OnlineSVGP # All models
11+
export AbstractLikelihood, RegressionLikelihood, ClassificationLikelihood, MultiClassLikelihood, EventLikelihood # All categories of likelihoods
12+
export GaussianLikelihood, StudentTLikelihood, LaplaceLikelihood, HeteroscedasticLikelihood # Regression Likelihoods
13+
export LogisticLikelihood, BayesianSVM # Classification Likelihoods
14+
export SoftMaxLikelihood, LogisticSoftMaxLikelihood # Multiclass Classification Likelihoods
15+
export PoissonLikelihood, NegBinomialLikelihood # Event Likelihoods
16+
export AbstractInference, Analytic, AnalyticVI, AnalyticSVI # Inference objects
17+
export GibbsSampling, HMCSampling # Sampling inference
18+
export NumericalVI, NumericalSVI, MCIntegrationVI, MCIntegrationSVI, QuadratureVI, QuadratureSVI # Numerical inference
19+
export PriorMean, ZeroMean, ConstantMean, EmpiricalMean, AffineMean # Prior means
20+
#Useful functions
21+
export train!, sample
22+
export predict_f, predict_y, proba_y
23+
export fstar
24+
export ELBO
25+
export covariance, diag_covariance, prior_mean
26+
export @augmodel
4527

46-
#Useful functions and module
47-
include(joinpath("functions", "PGSampler.jl"))
48-
include(joinpath("functions", "GIGSampler.jl"))
49-
include(joinpath("functions", "lap_transf_dist.jl"))
50-
#include("functions/PerturbativeCorrection.jl")
51-
# include("functions/GPAnalysisTools.jl")
52-
# include("functions/IO_model.jl")
53-
#Custom modules
54-
using .PGSampler
55-
using .GIGSampler
28+
#General modules
29+
using Reexport
30+
using LinearAlgebra
31+
using Random
32+
@reexport using KernelFunctions
33+
using KernelFunctions: ColVecs, RowVecs
34+
using Zygote, ForwardDiff
35+
using ChainRulesCore: ChainRulesCore, NO_FIELDS, DoesNotExist
36+
using Flux: params, destructure
37+
@reexport using Flux.Optimise
38+
using AdvancedHMC
39+
using MCMCChains
40+
using StatsBase
41+
@reexport using InducingPoints
42+
using StatsFuns
43+
using SpecialFunctions
44+
using Distributions:
45+
Distributions, Distribution,
46+
dim, cov, mean, var,
47+
pdf, logpdf, loglikelihood,
48+
Normal, Poisson, NegativeBinomial, InverseGamma, Laplace, MvNormal, Gamma
49+
using FastGaussQuadrature: gausshermite
50+
using ProgressMeter, SimpleTraits
5651

57-
include(joinpath("inducingpoints" , "InducingPoints.jl"))
58-
@reexport using .InducingPoints
52+
#Include custom module for additional distributions
53+
include(joinpath("ComplementaryDistributions", "ComplementaryDistributions.jl"))
54+
using .ComplementaryDistributions
5955

60-
# using .PerturbativeCorrection
61-
# using .GPAnalysisTools
62-
# using .IO_model
56+
# Main classes
57+
abstract type AbstractInference{T<:Real} end
58+
abstract type VariationalInference{T} <: AbstractInference{T} end
59+
abstract type SamplingInference{T} <: AbstractInference{T} end
60+
abstract type AbstractLikelihood{T<:Real} end
61+
abstract type AbstractLatent{T<:Real,Tpr,Tpo} end
6362

63+
include(joinpath("mean", "priormean.jl"))
64+
include(joinpath("data", "datacontainer.jl"))
65+
include(joinpath("functions", "utils.jl"))
6466

65-
# Main classes
66-
abstract type Inference{T<:Real} end
67-
abstract type VariationalInference{T} <: Inference{T} end
68-
abstract type SamplingInference{T} <: Inference{T} end
69-
abstract type Likelihood{T<:Real} end
70-
abstract type AbstractLatent{T<:Real,Tpr,Tpo} end
67+
# Models
68+
include(joinpath("models", "AbstractGP.jl"))
69+
include(joinpath("gpblocks", "latentgp.jl"))
70+
include(joinpath("models", "GP.jl"))
71+
include(joinpath("models", "VGP.jl"))
72+
include(joinpath("models", "MCGP.jl"))
73+
include(joinpath("models", "SVGP.jl"))
74+
include(joinpath("models", "VStP.jl"))
75+
include(joinpath("models", "MOSVGP.jl"))
76+
include(joinpath("models", "MOVGP.jl"))
77+
include(joinpath("models", "OnlineSVGP.jl"))
78+
include(joinpath("models", "single_output_utils.jl"))
79+
include(joinpath("models", "multi_output_utils.jl"))
7180

72-
include(joinpath("mean", "priormean.jl"))
73-
include(joinpath("data", "datacontainer.jl"))
74-
include(joinpath("functions", "utils.jl"))
81+
include(joinpath("inference", "inference.jl"))
82+
include(joinpath("likelihood", "likelihood.jl"))
83+
include(joinpath("likelihood", "generic_likelihood.jl"))
7584

76-
# Models
77-
include(joinpath("models", "AbstractGP.jl"))
78-
include(joinpath("gpblocks", "latentgp.jl"))
79-
include(joinpath("models", "GP.jl"))
80-
include(joinpath("models", "VGP.jl"))
81-
include(joinpath("models", "MCGP.jl"))
82-
include(joinpath("models", "SVGP.jl"))
83-
include(joinpath("models", "VStP.jl"))
84-
include(joinpath("models", "MOSVGP.jl"))
85-
include(joinpath("models", "MOVGP.jl"))
86-
include(joinpath("models", "OnlineSVGP.jl"))
87-
include(joinpath("models", "single_output_utils.jl"))
88-
include(joinpath("models", "multi_output_utils.jl"))
85+
include(joinpath("functions", "KLdivergences.jl"))
86+
include(joinpath("functions", "ELBO.jl"))
87+
include(joinpath("data", "utils.jl"))
88+
include(joinpath("functions", "plotting.jl"))
8989

90-
include(joinpath("inference", "inference.jl"))
91-
include(joinpath("likelihood", "likelihood.jl"))
90+
# Training and prediction functions
91+
include(joinpath("training", "training.jl"))
92+
include(joinpath("training", "onlinetraining.jl"))
93+
include(joinpath("hyperparameter", "autotuning.jl"))
94+
include(joinpath("training", "predictions.jl"))
95+
include("ar_predict.jl")
9296

93-
include(joinpath("likelihood", "generic_likelihood.jl"))
94-
95-
include(joinpath("functions", "KLdivergences.jl"))
96-
include(joinpath("data", "utils.jl"))
97-
include(joinpath("functions", "plotting.jl"))
98-
99-
# Training and prediction functions
100-
include(joinpath("training", "training.jl"))
101-
include(joinpath("training", "onlinetraining.jl"))
102-
include(joinpath("hyperparameter", "autotuning.jl"))
103-
include(joinpath("training", "predictions.jl"))
104-
include("ar_predict.jl")
10597
end #End Module
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module ComplementaryDistributions
2+
3+
using Distributions
4+
using Random
5+
using SpecialFunctions
6+
using StatsFuns: twoπ
7+
8+
export GeneralizedInverseGaussian, PolyaGamma, LaplaceTransformDistribution
9+
include("generalizedinversegaussian.jl")
10+
include("polyagamma.jl")
11+
include("lap_transf_dist.jl")
12+
13+
end

src/functions/GIGSampler.jl renamed to src/ComplementaryDistributions/generalizedinversegaussian.jl

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,55 @@
1-
"""Module for a Generalized Inverse Gaussian Sampler"""
2-
module GIGSampler
3-
4-
using Distributions
5-
using SpecialFunctions
6-
7-
export GeneralizedInverseGaussian
81

92
"""Sampler object"""
103
struct GeneralizedInverseGaussian{T<:Real} <: Distributions.ContinuousUnivariateDistribution
11-
a::T
4+
a::T
125
b::T
136
p::T
14-
function GeneralizedInverseGaussian{T}(a::T, b::T, p::T) where T
7+
function GeneralizedInverseGaussian{T}(a::T, b::T, p::T) where {T}
158
Distributions.@check_args(GeneralizedInverseGaussian, a > zero(a) && b > zero(b))
16-
new{T}(a, b, p)
9+
return new{T}(a, b, p)
1710
end
1811
end
1912

20-
function GeneralizedInverseGaussian(a::T, b::T, p::T) where T
21-
GeneralizedInverseGaussian{T}(a::T, b::T, p::T)
13+
function GeneralizedInverseGaussian(a::T, b::T, p::T) where {T}
14+
return GeneralizedInverseGaussian{T}(a::T, b::T, p::T)
2215
end
2316

2417
Distributions.params(d::GeneralizedInverseGaussian) = (d.a, d.b, d.p)
25-
@inline Distributions.partype(d::GeneralizedInverseGaussian{T}) where T <: Real = T
26-
18+
@inline Distributions.partype(::GeneralizedInverseGaussian{T}) where {T<:Real} = T
2719

2820
function Distributions.mean(d::GeneralizedInverseGaussian)
2921
a, b, p = params(d)
3022
q = sqrt(a * b)
31-
(sqrt(b) * besselk(p + 1, q)) / (sqrt(a) * besselk(p, q))
23+
return (sqrt(b) * besselk(p + 1, q)) / (sqrt(a) * besselk(p, q))
3224
end
3325

3426
function Distributions.var(d::GeneralizedInverseGaussian)
3527
a, b, p = params(d)
3628
q = sqrt(a * b)
3729
r = besselk(p, q)
38-
(b / a) * ((besselk(p + 2, q) / r) - (besselk(p + 1, q) / r)^2)
30+
return (b / a) * ((besselk(p + 2, q) / r) - (besselk(p + 1, q) / r)^2)
3931
end
4032

41-
Distributions.mode(d::GeneralizedInverseGaussian) = ((d.p - 1) + sqrt((d.p - 1)^2 + d.a * d.b)) / d.a
42-
33+
function Distributions.mode(d::GeneralizedInverseGaussian)
34+
return ((d.p - 1) + sqrt((d.p - 1)^2 + d.a * d.b)) / d.a
35+
end
4336

44-
function Distributions.pdf(d::GeneralizedInverseGaussian{T}, x::Real) where T <: Real
37+
function Distributions.pdf(d::GeneralizedInverseGaussian{T}, x::Real) where {T<:Real}
4538
if x > 0
4639
a, b, p = params(d)
47-
(((a / b)^(p / 2)) / (2 * besselk(p, sqrt(a * b)))) * (x^(p - 1)) * exp(- (a * x + b / x) / 2)
40+
(((a / b)^(p / 2)) / (2 * besselk(p, sqrt(a * b)))) *
41+
(x^(p - 1)) *
42+
exp(-(a * x + b / x) / 2)
4843
else
4944
zero(T)
5045
end
5146
end
5247

53-
function Distributions.logpdf(d::GeneralizedInverseGaussian{T}, x::Real) where T <: Real
48+
function Distributions.logpdf(d::GeneralizedInverseGaussian{T}, x::Real) where {T<:Real}
5449
if x > 0
5550
a, b, p = params(d)
56-
(p / 2) * (log(a) - log(b)) - log(2 * besselk(p, sqrt(a * b))) + (p - 1) * log(x) - (a * x + b / x) / 2
51+
(p / 2) * (log(a) - log(b)) - log(2 * besselk(p, sqrt(a * b))) + (p - 1) * log(x) -
52+
(a * x + b / x) / 2
5753
else
5854
-T(Inf)
5955
end
@@ -71,10 +67,10 @@ function Distributions.rand(d::GeneralizedInverseGaussian)
7167
else
7268
x = _hormann(λ, β)
7369
end
74-
p >= 0 ? x / α : 1 / (α * x)
70+
return p >= 0 ? x / α : 1 / (α * x)
7571
end
7672
function _gigqdf(x::Real, λ::Real, β::Real)
77-
(x^(λ - 1)) * exp(-β * (x + 1 / x) / 2)
73+
return (x^(λ - 1)) * exp(-β * (x + 1 / x) / 2)
7874
end
7975

8076
function _hormann(λ::Real, β::Real)
@@ -166,5 +162,3 @@ function _rou_shift(λ::Real, β::Real)
166162
end
167163
end
168164
end
169-
170-
end #module GIGSampler

0 commit comments

Comments
 (0)