Skip to content

Commit a7c576a

Browse files
authored
Merge pull request #67 from theogf/Additionalfixes
more general fixes
2 parents 089025c + 8916aa9 commit a7c576a

35 files changed

+208
-300
lines changed

.travis.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,20 @@ os:
44
- linux
55
- osx
66
julia:
7-
- 1.2
8-
- 1.4
7+
- 1.3
8+
- 1
99
- nightly
1010
notifications:
1111
email: false
1212
after_success:
1313
# push coverage results to Coveralls
14-
- if [[ $TRAVIS_JULIA_VERSION = 1.4 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
14+
- if [[ $TRAVIS_JULIA_VERSION = 1 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
1515
julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Coveralls.submit(process_folder())';
1616
fi
1717
jobs:
1818
include:
1919
- stage: "Documentation"
20-
julia: 1.4
20+
julia: 1
2121
os: linux
2222
script:
2323
- export DOCUMENTER_DEBUG=true

Project.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "AugmentedGaussianProcesses"
22
uuid = "38eea1fd-7d7d-5162-9d08-f89d0f2e271e"
33
authors = ["Theo Galy-Fajou <theo.galyfajou@gmail.com>"]
4-
version = "0.8.5"
4+
version = "0.9.0"
55

66
[deps]
77
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
88
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
99
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1010
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1111
DeterminantalPointProcesses = "4d968f93-c0cd-4b7f-b189-b034d1a24a0e"
12+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
1213
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1314
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
1415
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -25,7 +26,6 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
2526
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2627
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
2728
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
28-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2929
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3030
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3131
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -51,7 +51,7 @@ RecipesBase = "1.0, 1.1"
5151
Reexport = "0.2"
5252
SimpleTraits = "0.9"
5353
SpecialFunctions = "0.9, 0.10"
54-
StatsBase = "0.32.0, 0.33"
55-
StatsFuns = "0.8.0, 0.9"
56-
Zygote = "0.4, 0.5"
57-
julia = "1.2"
54+
StatsBase = "0.32, 0.33"
55+
StatsFuns = "0.8, 0.9"
56+
Zygote = "0.5"
57+
julia = "1.3"

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[![Docs Latest](https://img.shields.io/badge/docs-dev-blue.svg)](https://theogf.github.io/AugmentedGaussianProcesses.jl/dev)
44
[![Docs Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://theogf.github.io/AugmentedGaussianProcesses.jl/stable)
5-
[![Build Status](https://travis-ci.org/theogf/AugmentedGaussianProcesses.jl.svg?branch=master)](https://travis-ci.org/theogf/AugmentedGaussianProcesses.jl)
5+
[![Build Status](https://travis-ci.com/theogf/AugmentedGaussianProcesses.jl.svg?branch=master)](https://travis-ci.com/theogf/AugmentedGaussianProcesses.jl)
66
[![Coverage Status](https://coveralls.io/repos/github/theogf/AugmentedGaussianProcesses.jl/badge.svg?branch=master)](https://coveralls.io/github/theogf/AugmentedGaussianProcesses.jl?branch=master)
77
[![DOI](https://zenodo.org/badge/118922202.svg)](https://zenodo.org/badge/latestdoi/118922202)
88

@@ -59,7 +59,7 @@ AugmentedGaussianProcesses.jl is a Julia package in development for **Data Augme
5959

6060
## Install the package
6161

62-
The package requires at least [Julia 1.1](https://julialang.org/)
62+
The package requires at least [Julia 1.3](https://julialang.org/)
6363
Run `julia`, press `]` and type `add AugmentedGaussianProcesses`, it will install the package and all its dependencies.
6464

6565
## Use the package

docs/examples/multiclassgp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end;
6868

6969
function plot_contour(model, σ)
7070
n_grid = 100
71-
global pred_proba, pred, x, y = compute_grid(model, n_grid);
71+
pred_proba, pred, x, y = compute_grid(model, n_grid);
7272
colors = reshape(
7373
[
7474
RGB([pred_proba[model.likelihood.ind_mapping[j]][i] for j in 1:n_class]...)

src/ar_predict.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ ks = []
2121
@assert length(y_past) == m.nTask
2222
@assert all(length.(y_past).>= p)
2323
@assert all(m.nDim .== p)
24-
global Xtest = [reshape(y[(end-p+1):end],1,:) for y in y_past]
24+
Xtest = [reshape(y[(end-p+1):end],1,:) for y in y_past]
2525
y_new = [zeros(T, n) for _ in 1:m.nTask]
2626
for i in 1:n
2727
setindex!.(y_new, first.(first.(first(_predict_f(m, Xtest, covf = false)))), i)

src/data/datacontainer.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,20 @@ struct DataContainer{
1414
TY<:AbstractVector,
1515
} <: AbstractDataContainer
1616
X::TX # Feature vectors
17-
y::TY # Output (-1,1 for classification, real for regression, matrix for multiclass)
17+
y::TY # Output (-1,1 for classification, real for regression, vector{vector} for multiclass)
1818
nSamples::Int # Number of samples
1919
nDim::Int # Number of features per sample
2020
end
2121

2222
function wrap_data(X::TX, y::TY) where {TX, TY<:AbstractVector{<:Real}}
23-
size(y, 1) == size(X, 1) || error("There is not the same number of samples in X ($(length(TX))) and y ($(size(y, 1)))")
23+
length(y) == length(X) || error("There is not the same number of samples in X ($(length(X))) and y ($(length(y)))")
2424
Tx = eltype(first(X))
2525
Ty = eltype(first(y))
2626
return DataContainer{Tx, TX, Ty, TY}(X, y, length(X), length(first(X)))
2727
end
2828

2929
function wrap_data(X::TX, y::TY) where {TX, TY<:AbstractVector}
30-
size(first(y), 1) == size(X, 1) || error("There is not the same number of samples in X ($(length(TX))) and y ($(size(y, 1)))")
30+
all(length.(y) .== length(X)) || error("There is not the same number of samples in X ($(length(X))) and y ($(length.(y))))")
3131
Tx = eltype(first(X))
3232
Ty = eltype(first(y))
3333
return DataContainer{Tx, TX, Ty, TY}(X, y, length(X), length(first(X)))
@@ -46,7 +46,7 @@ struct MODataContainer{
4646
end
4747

4848
function wrap_modata(X::TX, y::TY) where {TX, TY<:AbstractVector}
49-
all(size.(y, 1) .== size(X, 1)) || error("There is not the same number of samples in X ($(length(TX))) and y ($(size(y, 1)))")
49+
all(length.(y) .== length(X)) || error("There is not the same number of samples in X ($(length(X))) and y ($(length.(y))))")
5050
Tx = eltype(first(X))
5151
return MODataContainer{Tx, TX, TY}(X, y, length(X), length(first(X)), length(y))
5252
end

src/data/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ view_x(d::AbstractDataContainer, indices) = view(d.X, indices)
22
view_y(l::Likelihood, y::AbstractVector, i::AbstractVector) = view(y, i)
33
view_y(l::Likelihood, d::AbstractDataContainer, i::AbstractVector) = view_y(l, output(d), i)
44
view_y(l::Likelihood, d::MODataContainer, i::AbstractVector) = view_y.(l, output(d), Ref(i))
5+
view_y(l::AbstractVector{<:Likelihood}, d::MODataContainer, i::AbstractVector) = view_y.(l, output(d), Ref(i))
6+
57

68
# Verify that the data is self-consistent and consistent with the likelihood ##
79
function check_data!(

src/examples/gpregression.md/GP Regression.md

Lines changed: 0 additions & 178 deletions
This file was deleted.

src/functions/KLdivergences.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
"""Compute the KL Divergence between the GP Prior and the variational distribution"""
2-
GaussianKL(model::AbstractGP) = sum(broadcast(GaussianKL, model.f, Zviews(model)))
1+
"""
2+
KL Divergence between the GP Prior and the variational distribution
3+
"""
4+
GaussianKL(model::AbstractGP) = mapreduce(GaussianKL, +, model.f, Zviews(model))
35

46
GaussianKL(gp::AbstractLatent, X::AbstractVector) = GaussianKL(mean(gp), pr_mean(gp, X), cov(gp), pr_cov(gp))
57

@@ -13,15 +15,17 @@ function GaussianKL(
1315
0.5 * (logdet(K) - logdet(Σ) + tr(K \ Σ) + invquad(K, μ - μ₀) - length(μ))
1416
end
1517

16-
extraKL(model::AbstractGP{T}) where {T} = zero(T)
18+
extraKL(::AbstractGP{T}) where {T} = zero(T)
1719

18-
"""Return the extra KL term containing the divergence with the GP at time t and t+1"""
20+
"""
21+
Extra KL term containing the divergence with the GP at time t and t+1
22+
"""
1923
function extraKL(model::OnlineSVGP{T}) where {T}
2024
KLₐ = zero(T)
2125
for gp in model.f
2226
κₐμ = gp.κₐ * mean(gp)
2327
KLₐ += gp.prev𝓛ₐ
24-
KLₐ += -0.5 * sum(opt_trace.([gp.invDₐ], [gp.K̃ₐ, gp.κₐ * cov(gp) * transpose(gp.κₐ)]))
28+
KLₐ += -0.5 * sum(opt_trace.(Ref(gp.invDₐ), [gp.K̃ₐ, gp.κₐ * cov(gp) * transpose(gp.κₐ)]))
2529
KLₐ += dot(gp.prevη₁, κₐμ) - 0.5 * dot(κₐμ, gp.invDₐ * κₐμ)
2630
end
2731
return KLₐ
@@ -57,21 +61,22 @@ function PoissonKL(
5761
end
5862

5963

60-
"""KL(q(ω)||p(ω)), where q(ω) = PG(b,c) and p(ω) = PG(b,0). θ = 𝑬[ω]"""
64+
"""
65+
KL(q(ω)||p(ω)), where q(ω) = PG(b,c) and p(ω) = PG(b,0). θ = 𝑬[ω]
66+
"""
6167
function PolyaGammaKL(b, c, θ)
6268
dot(b, logcosh.(0.5 * c)) - 0.5 * dot(abs2.(c), θ)
6369
end
6470

65-
6671
"""
6772
Entropy of GIG variables with parameters a,b and p and omitting the derivative d/dpK_p cf <https://en.wikipedia.org/wiki/Generalized_inverse_Gaussian_distribution#Entropy>
6873
"""
6974
function GIGEntropy(a, b, p)
70-
sqrtab = sqrt.(a .* b)
71-
return sum(0.5 * log.(a ./ b)) +
72-
sum(log.(2 * besselk.(p, sqrtab))) +
75+
sqrt_ab = sqrt.(a .* b)
76+
return 0.5 * (sum(log, a) - sum(log, b)) +
77+
mapreduce((p, s) -> log(2 * besselk(p, s), +, p, sqrt_ab)) +
7378
sum(
74-
0.5 * sqrtab ./ besselk.(p, sqrtab) .*
75-
(besselk.(p + 1, sqrtab) + besselk.(p - 1, sqrtab)),
79+
0.5 * sqrt_ab ./ besselk.(p, sqrt_ab) .*
80+
(besselk.(p + 1, sqrt_ab) + besselk.(p - 1, sqrt_ab)),
7681
)
7782
end

0 commit comments

Comments
 (0)