Skip to content

Commit 0ede548

Browse files
authored
Provide more interface convenience functions (#47)
This fleshes out the interface of GMMs as vector- or set-like (implementing `push!` and `pop!`), and MultiGMMs as dictionary-like (implementing `valtype`, `get`, `get!`, `delete!`, and `empty!`). This makes it easier to write code without needing to reach into internals. The new interface functions are used to simplify some of the other code.
1 parent 6773f81 commit 0ede548

File tree

5 files changed

+71
-11
lines changed

5 files changed

+71
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GaussianMixtureAlignment"
22
uuid = "f2431ed1-b9c2-4fdb-af1b-a74d6c93b3b3"
33
authors = ["Tom McGrath <[email protected]> and contributors"]
4-
version = "0.2.1"
4+
version = "0.2.2"
55

66
[deps]
77
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"

src/GaussianMixtureAlignment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ isotropic (spherical) Gaussian distributions.
99
REPL help
1010
=========
1111
12-
? followed by an algorith or constructor name will print help to the terminal. See: \n
12+
? followed by an algorithm or constructor name will print help to the terminal. See: \n
1313
\t?IsotropicGaussian \n
1414
\t?IsotropicGMM \n
1515
\t?IsotropicMultiGMM \n

src/draw.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function plot!(gd::GMMDisplay{<:NTuple{<:Any,<:AbstractIsotropicGMM}})
9191
label = gd[:label][]
9292
for (i,gmm) in enumerate(gmms)
9393
col = isnothing(color) ? palette[(i-1) % len + 1] : color
94-
gaussiandisplay!(gd, gmm.gaussians...; display=disp, color=col, label)
94+
gaussiandisplay!(gd, gmm...; display=disp, color=col, label)
9595
end
9696
return gd
9797
end
@@ -103,13 +103,13 @@ function plot!(gd::GMMDisplay{<:NTuple{<:Any,<:AbstractIsotropicMultiGMM{N,T,K}}
103103
palette = gd[:palette][]
104104
allkeys = Set{K}()
105105
for mgmm in mgmms
106-
allkeys = allkeys keys(mgmm.gmms)
106+
allkeys = allkeys keys(mgmm)
107107
end
108108
len = length(allkeys)
109109
for (i,k) in enumerate(allkeys)
110110
col = isnothing(color) ? palette[(i-1) % len + 1] : color
111111
for mgmm in mgmms
112-
haskey(mgmm.gmms, k) && gmmdisplay!(gd, mgmm.gmms[k]; display=disp, color=col, palette=palette, label=string(k))
112+
haskey(mgmm, k) && gmmdisplay!(gd, mgmm[k]; display=disp, color=col, palette=palette, label=string(k))
113113
end
114114
end
115115
return gd

src/gogma/gmm.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import Base: eltype, length, size, getindex, iterate, convert, promote_rule, keys
1+
import Base: eltype, keytype, valtype, length, size, getindex, iterate, convert, promote_rule,
2+
keys, values, push!, pop!, empty!, haskey, get, get!, delete!
23

34
# Type structure: leaving things open for adding anisotropic Gaussians and GMMs
45

@@ -41,7 +42,10 @@ iterate(gmm::AbstractSingleGMM) = iterate(gmm.gaussians)
4142
iterate(gmm::AbstractSingleGMM, i) = iterate(gmm.gaussians, i)
4243
size(gmm::AbstractSingleGMM{N,T}) where {N,T} = (length(gmm.gaussians), N)
4344
size(gmm::AbstractSingleGMM{N,T}, idx::Int) where {N,T} = (length(gmm.gaussians), N)[idx]
44-
eltype(gmm::AbstractSingleGMM) = eltype(gmm.gaussians);
45+
eltype(gmm::AbstractSingleGMM) = eltype(gmm.gaussians)
46+
push!(gmm::AbstractSingleGMM, g::AbstractGaussian) = push!(gmm.gaussians, g)
47+
pop!(gmm::AbstractSingleGMM) = pop!(gmm.gaussians)
48+
empty!(gmm::AbstractSingleGMM) = empty!(gmm.gaussians)
4549

4650
coords(gmm::AbstractSingleGMM) = hcat([g.μ for g in gmm.gaussians]...)
4751
weights(gmm::AbstractSingleGMM) = [g.ϕ for g in gmm.gaussians]
@@ -54,7 +58,16 @@ iterate(mgmm::AbstractMultiGMM) = iterate(mgmm.gmms)
5458
iterate(mgmm::AbstractMultiGMM, i) = iterate(mgmm.gmms, i)
5559
size(mgmm::AbstractMultiGMM{N,T,K}) where {N,T,K} = (length(mgmm.gmms), N)
5660
size(mgmm::AbstractMultiGMM{N,T,K}, idx::Int) where {N,T,K} = (length(mgmm.gmms), N)[idx]
57-
eltype(mgmm::AbstractMultiGMM) = eltype(mgmm.gmms);
61+
eltype(mgmm::AbstractMultiGMM) = eltype(mgmm.gmms)
62+
eltype(::Type{MGMM}) where MGMM<:AbstractMultiGMM = Pair{keytype(MGMM),valtype(MGMM)}
63+
keytype(mgmm::AbstractMultiGMM) = keytype(typeof(mgmm))
64+
keytype(::Type{<:AbstractMultiGMM{N,T,K}}) where {N,T,K} = K
65+
valtype(mgmm::AbstractMultiGMM) = valtype(mgmm.gmms)
66+
haskey(mgmm::AbstractMultiGMM, k) = haskey(mgmm.gmms, k)
67+
get(mgmm::AbstractMultiGMM, k, default) = get(mgmm.gmms, k, default)
68+
get!(::Type{V}, mgmm::AbstractMultiGMM, k) where V = get!(V, mgmm.gmms, k)
69+
delete!(mgmm::AbstractMultiGMM, k) = delete!(mgmm.gmms, k)
70+
empty!(mgmm::AbstractMultiGMM) = empty!(mgmm.gmms)
5871

5972
coords(mgmm::AbstractMultiGMM) = hcat([coords(gmm) for (k,gmm) in mgmm.gmms]...)
6073
weights(mgmm::AbstractMultiGMM) = vcat([weights(gmm) for (k,gmm) in mgmm.gmms]...)
@@ -79,7 +92,7 @@ end
7992

8093
IsotropicGaussian(g::AbstractIsotropicGaussian) = IsotropicGaussian(g.μ, g.σ, g.ϕ)
8194

82-
convert(::Type{IsotropicGaussian{N,T}}, g::AbstractIsotropicGaussian) where {N,T} = IsotropicGaussian(SVector{N,T}(g.μ), T(g.σ), T(g.ϕ))
95+
convert(::Type{IsotropicGaussian{N,T}}, g::AbstractIsotropicGaussian) where {N,T} = IsotropicGaussian{N,T}(g.μ, g.σ, g.ϕ)
8396
promote_rule(::Type{IsotropicGaussian{N,T}}, ::Type{IsotropicGaussian{N,S}}) where {N,T<:Real,S<:Real} = IsotropicGaussian{N,promote_type(T,S)}
8497

8598
(g::IsotropicGaussian)(pos::AbstractVector) = exp(-sum(abs2, pos-g.μ)/(2*g.σ^2))*g.ϕ
@@ -92,11 +105,13 @@ struct IsotropicGMM{N,T} <: AbstractIsotropicGMM{N,T}
92105
end
93106

94107
IsotropicGMM(gmm::AbstractIsotropicGMM) = IsotropicGMM(gmm.gaussians)
108+
IsotropicGMM{N,T}() where {N,T} = IsotropicGMM{N,T}(IsotropicGaussian{N,T}[])
95109

96-
convert(t::Type{IsotropicGMM}, gmm::AbstractIsotropicGMM) = t(gmm.gaussians)
110+
convert(::Type{GMM}, gmm::AbstractIsotropicGMM) where GMM<:IsotropicGMM = GMM(gmm.gaussians)
97111
promote_rule(::Type{IsotropicGMM{N,T}}, ::Type{IsotropicGMM{N,S}}) where {T,S,N} = IsotropicGMM{N,promote_type(T,S)}
112+
eltype(::Type{IsotropicGMM{N,T}}) where {N,T} = IsotropicGaussian{N,T}
98113

99-
(gmm::IsotropicGMM)(pos::AbstractVector) = sum(g(pos) for g in gmm.gaussians)
114+
(gmm::IsotropicGMM)(pos::AbstractVector) = sum(g(pos) for g in gmm)
100115

101116
"""
102117
A collection of labeled `IsotropicGMM`s, to each be considered separately during an alignment procedure. That is,
@@ -110,6 +125,7 @@ IsotropicMultiGMM(gmm::AbstractIsotropicMultiGMM) = IsotropicMultiGMM(gmm.gmms)
110125

111126
convert(t::Type{IsotropicMultiGMM}, mgmm::AbstractIsotropicMultiGMM) = t(mgmm.gmms)
112127
promote_rule(::Type{IsotropicMultiGMM{N,T,K}}, ::Type{IsotropicMultiGMM{N,S,K}}) where {N,T,S,K} = IsotropicMultiGMM{N,promote_type(T,S),K}
128+
valtype(::Type{IsotropicMultiGMM{N,T,K}}) where {N,T,K} = IsotropicGMM{N,T}
113129

114130
# descriptive display
115131
# TODO update to display type parameters, make use of supertypes, etc

test/runtests.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,50 @@ end
9696
end
9797
end
9898

99+
@testset "GMM interface" begin
100+
tetrahedral = [
101+
[0.,0.,1.],
102+
[sqrt(8/9), 0., -1/3],
103+
[-sqrt(2/9),sqrt(2/3),-1/3],
104+
[-sqrt(2/9),-sqrt(2/3),-1/3]
105+
]
106+
ch_g = IsotropicGaussian(tetrahedral[1], 1.0, 1.0)
107+
s_gs = [IsotropicGaussian(x, 0.5, 1.0) for (i,x) in enumerate(tetrahedral)]
108+
gmm = IsotropicGMM(s_gs)
109+
@test length(gmm) == 4
110+
@test gmm[2] == s_gs[2]
111+
@test collect(gmm) == s_gs # tests iterate
112+
@test eltype(gmm) === eltype(typeof(gmm)) === IsotropicGaussian{3,Float64}
113+
@test convert(IsotropicGMM{3,Float32}, gmm) isa IsotropicGMM{3,Float32}
114+
@test_throws DimensionMismatch convert(IsotropicGMM{2,Float64}, gmm)
115+
mgmmx = IsotropicMultiGMM(Dict(
116+
:positive => IsotropicGMM([ch_g]),
117+
:steric => gmm
118+
))
119+
@test keys(mgmmx) == Set([:positive, :steric])
120+
@test keytype(mgmmx) === keytype(typeof(mgmmx)) === Symbol
121+
@test valtype(mgmmx) === valtype(typeof(mgmmx)) === IsotropicGMM{3,Float64}
122+
@test eltype(mgmmx) === eltype(typeof(mgmmx)) === Pair{Symbol, IsotropicGMM{3,Float64}}
123+
@test length(mgmmx) == 2
124+
@test length(mgmmx[:steric]) == 4
125+
@test mgmmx[:steric][2] == s_gs[2]
126+
@test collect(mgmmx) == collect(mgmmx.gmms) # tests iterate
127+
@test get!(valtype(mgmmx), mgmmx, :positive) == mgmmx[:positive]
128+
gmm = get!(valtype(mgmmx), mgmmx, :acceptor)
129+
@test isempty(gmm) && gmm isa IsotropicGMM{3,Float64}
130+
push!(gmm, ch_g)
131+
@test length(gmm) == 1
132+
pop!(gmm)
133+
@test isempty(gmm)
134+
push!(gmm, ch_g)
135+
empty!(gmm)
136+
@test isempty(gmm)
137+
delete!(mgmmx, :acceptor)
138+
@test !haskey(mgmmx, :acceptor)
139+
empty!(mgmmx)
140+
@test isempty(mgmmx)
141+
end
142+
99143
@testset "bounds for shrinking searchspace around an optimum" begin
100144
# two sets of points, each forming a 3-4-5 triangle
101145
xpts = [[0.,0.,0.], [3.,0.,0.,], [0.,4.,0.]]

0 commit comments

Comments
 (0)