Skip to content

Commit a52eb6f

Browse files
authored
Switch to easier format for interactions (#45)
* Switch to easier format for interactions * Added test for interactions validation * Revisions from Tim, v0.2
1 parent 96c0012 commit a52eb6f

File tree

4 files changed

+46
-36
lines changed

4 files changed

+46
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GaussianMixtureAlignment"
22
uuid = "f2431ed1-b9c2-4fdb-af1b-a74d6c93b3b3"
33
authors = ["Tom McGrath <[email protected]> and contributors"]
4-
version = "0.1.11"
4+
version = "0.2"
55

66
[deps]
77
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"

src/gogma/bounds.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
loose_distance_bounds(x::AbstractGaussian, y::AbstractGaussian, args...) = loose_distance_bounds(x.μ, y.μ, args...)
22
tight_distance_bounds(x::AbstractGaussian, y::AbstractGaussian, args...) = tight_distance_bounds(x.μ, y.μ, args...)
33

4+
function validate_interactions(interactions::Dict{Tuple{K,K},V}) where {K,V<:Number}
5+
for (k1,k2) in keys(interactions)
6+
if k1 != k2
7+
if haskey(interactions, (k2,k1))
8+
return false
9+
end
10+
end
11+
end
12+
return true
13+
end
14+
415
# prepare pairwise values for `σx^2 + σy^2` and `ϕx * ϕy` for all gaussians in `gmmx` and `gmmy`
516
function pairwise_consts(gmmx::AbstractIsotropicGMM, gmmy::AbstractIsotropicGMM, interactions=nothing)
617
t = promote_type(numbertype(gmmx),numbertype(gmmy))
@@ -14,26 +25,31 @@ function pairwise_consts(gmmx::AbstractIsotropicGMM, gmmy::AbstractIsotropicGMM,
1425
return pσ, pϕ
1526
end
1627

17-
function pairwise_consts(mgmmx::AbstractMultiGMM{N,T,K}, mgmmy::AbstractMultiGMM{N,S,K}, interactions::Union{Nothing,Dict{K,Dict{K,V}}}=nothing) where {N,T,S,K,V <: Number}
28+
function pairwise_consts(mgmmx::AbstractMultiGMM{N,T,K}, mgmmy::AbstractMultiGMM{N,S,K}, interactions::Union{Nothing,Dict{Tuple{K,K},V}}=nothing) where {N,T,S,K,V <: Number}
1829
t = promote_type(numbertype(mgmmx),numbertype(mgmmy), isnothing(interactions) ? numbertype(mgmmx) : V)
1930
xkeys = keys(mgmmx.gmms)
2031
ykeys = keys(mgmmy.gmms)
2132
if isnothing(interactions)
22-
interactions = Dict{K, Dict{K, t}}()
33+
interactions = Dict{Tuple{K,K},t}()
2334
for key in xkeys ykeys
24-
interactions[key] = Dict{K, t}(key => one(t))
35+
interactions[(key,key)] = one(t)
2536
end
37+
else
38+
@assert validate_interactions(interactions) "Interactions must not include redundant key pairs (i.e. (k1,k2) and (k2,k1))"
2639
end
2740
mpσ, mpϕ = Dict{K, Dict{K, Matrix{t}}}(), Dict{K, Dict{K,Matrix{t}}}()
28-
for key1 in keys(interactions)
41+
ukeys = unique(Iterators.flatten(keys(interactions)))
42+
for key1 in ukeys
2943
if key1 xkeys
3044
push!(mpσ, key1 => Dict{K, Matrix{t}}())
3145
push!(mpϕ, key1 => Dict{K, Matrix{t}}())
32-
for key2 in keys(interactions[key1])
33-
if key2 ykeys
46+
for key2 in ukeys
47+
keypair = (key1,key2)
48+
keypair = haskey(interactions, keypair) ? keypair : (key2,key1)
49+
if key2 ykeys && haskey(interactions, keypair)
3450
pσ, pϕ = pairwise_consts(mgmmx.gmms[key1], mgmmy.gmms[key2])
3551
push!(mpσ[key1], key2 => pσ)
36-
push!(mpϕ[key1], key2 => interactions[key1][key2] .* pϕ)
52+
push!(mpϕ[key1], key2 => interactions[keypair] .* pϕ)
3753
end
3854
end
3955
if isempty(mpσ[key1])

src/gogma/overlap.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ end
104104

105105
function force!(f::AbstractVector, x::AbstractIsotropicGaussian, y::AbstractIsotropicGMM, pσ=nothing, pϕ=nothing; kwargs...)
106106
if isnothing(pσ) && isnothing(pϕ)
107-
= x.σ^2 .+ [gy.σ^2 for gy in y.gaussians]
108-
= [ x.ϕ * gy.ϕ for gy in y.gaussians]
107+
xσsq = x.σ^2
108+
= [xσsq + gy.σ^2 for gy in y.gaussians]
109+
= [x.ϕ * gy.ϕ for gy in y.gaussians]
109110
end
110111
for (gy, s, w) in zip(y.gaussians, pσ, pϕ)
111112
force!(f, x, gy, s, w; kwargs...)
@@ -121,10 +122,8 @@ function force!(f::AbstractVector, x::AbstractIsotropicGMM, y::AbstractIsotropic
121122
end
122123
end
123124

124-
function force!(f::AbstractVector, x::AbstractMultiGMM, y::AbstractMultiGMM, mpσ=nothing, mpϕ=nothing; interactions=nothing)
125-
if isnothing(mpσ) && isnothing(mpϕ)
126-
mpσ, mpϕ = pairwise_consts(x, y, interactions)
127-
end
125+
function force!(f::AbstractVector, x::AbstractMultiGMM, y::AbstractMultiGMM; interactions=nothing)
126+
mpσ, mpϕ = pairwise_consts(x, y, interactions)
128127
for k1 in keys(mpσ)
129128
for k2 in keys(mpσ[k1])
130129
# don't pass coef as a keyword argument, since the interaction coefficient is baked into mpϕ

test/runtests.jl

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -158,18 +158,20 @@ end
158158
:negative => IsotropicGMM([ch_g]),
159159
:steric => IsotropicGMM(s_gs)
160160
))
161+
# interaction validation
161162
interactions = Dict(
162-
:positive => Dict(
163-
:positive => -1.0,
164-
:negative => 1.0,
165-
),
166-
:negative => Dict(
167-
:positive => 1.0,
168-
:negative => -1.0,
169-
),
170-
:steric => Dict(
171-
:steric => -1.0,
172-
),
163+
(:positive, :negative) => 1.0,
164+
(:negative, :positive) => 1.0,
165+
(:positive, :positive) => -1.0,
166+
(:negative, :negative) => -1.0,
167+
(:steric, :steric) => -1.0,
168+
)
169+
@test_throws AssertionError GMA.pairwise_consts(mgmmx, mgmmy, interactions)
170+
interactions = Dict(
171+
(:positive, :negative) => 1.0,
172+
(:positive, :positive) => -1.0,
173+
(:negative, :negative) => -1.0,
174+
(:steric, :steric) => -1.0,
173175
)
174176
randtform = AffineMap(RotationVec*0.1rand(3)...), SVector{3}(0.1*rand(3)...))
175177
res = gogma_align(randtform(mgmmx), mgmmy; interactions=interactions, maxsplits=5e3, nextblockfun=GMA.randomblock)
@@ -208,17 +210,10 @@ end
208210
fliptform = AffineMap(RotationVec(π,0,0),[0,0,3]) AffineMap(RotationVec(0,0,π),[0,0,0])
209211
mgmmy = fliptform(mgmmy)
210212
interactions = Dict(
211-
:positive => Dict(
212-
:positive => -1.0,
213-
:negative => 1.0,
214-
),
215-
:negative => Dict(
216-
:positive => 1.0,
217-
:negative => -1.0,
218-
),
219-
:steric => Dict(
220-
:steric => -1.0,
221-
),
213+
(:positive, :negative) => 1.0,
214+
(:positive, :positive) => -1.0,
215+
(:negative, :negative) => -1.0,
216+
(:steric, :steric) => -1.0,
222217
)
223218
f = zeros(3)
224219
force!(f, mgmmx, mgmmy; interactions=interactions)

0 commit comments

Comments
 (0)