Skip to content

Commit d0c9941

Browse files
authored
Added support for repulsion and interaction maps for MultiGMMs (#39)
* Added support for repulsion and interaction maps for MultiGMMs * v0.1.9
1 parent c926965 commit d0c9941

File tree

8 files changed

+134
-45
lines changed

8 files changed

+134
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GaussianMixtureAlignment"
22
uuid = "f2431ed1-b9c2-4fdb-af1b-a74d6c93b3b3"
33
authors = ["Tom McGrath <[email protected]> and contributors"]
4-
version = "0.1.8"
4+
version = "0.1.9"
55

66
[deps]
77
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"

src/distancebounds.jl

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,29 @@ function infbounds(x,y)
77
return (typeinf, typeinf)
88
end
99

10-
function loose_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number}, σᵣ::Number, σₜ::Number)
10+
function loose_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number}, σᵣ::Number, σₜ::Number, maximize::Bool = false)
1111
ubdist = norm(x - y)
1212
γₜ = sqrt3 * σₜ
1313
γᵣ = 2 * sin(min(sqrt3 * σᵣ, π) / 2) * norm(x)
14-
lb, ub = max(ubdist - γₜ - γᵣ, 0), ubdist
14+
lb, ub = maximize ? (max(ubdist - γₜ - γᵣ, 0), ubdist) : (ubddist + γₜ + γᵣ, ubdist)
1515
numtype = promote_type(typeof(lb), typeof(ub))
1616
return numtype(lb), numtype(ub)
1717
end
18-
loose_distance_bounds(x::SVector{3}, y::SVector{3}, R::RotationVec, T::SVector{3}, σᵣ, σₜ
19-
) = (R.sx^2 + R.sy^2 + R.sz^2) > pisq ? infbounds(x,y) : loose_distance_bounds(R*x, y-T, σᵣ, σₜ) # loose_distance_bounds(R*x, y-T, σᵣ, σₜ)
20-
loose_distance_bounds(x::SVector{3}, y::SVector{3}, block::UncertaintyRegion) = loose_distance_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ)
21-
loose_distance_bounds(x::SVector{3}, y::SVector{3}, block::SearchRegion) = loose_distance_bounds(x, y, UncertaintyRegion(block))
18+
loose_distance_bounds(x::SVector{3}, y::SVector{3}, R::RotationVec, T::SVector{3}, σᵣ, σₜ, maximize::Bool = false,
19+
) = (R.sx^2 + R.sy^2 + R.sz^2) > pisq ? infbounds(x,y) : loose_distance_bounds(R*x, y-T, σᵣ, σₜ, maximize) # loose_distance_bounds(R*x, y-T, σᵣ, σₜ)
20+
loose_distance_bounds(x::SVector{3}, y::SVector{3}, block::UncertaintyRegion, maximize::Bool = false) = loose_distance_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, maximize)
21+
loose_distance_bounds(x::SVector{3}, y::SVector{3}, block::SearchRegion, maximize::Bool = false) = loose_distance_bounds(x, y, UncertaintyRegion(block), maximize)
2222

2323

2424
"""
2525
lb, ub = tight_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number}, σᵣ::Number, σₜ::Number)
2626
lb, ub = tight_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number}, R::RotationVec, T<:SVector{3}, σᵣ::Number, σₜ::Number)
2727
28-
Within an uncertainty region, find the bounds on distance between two points x and y.
28+
Within an uncertainty region, find the bounds on distance between two points x and y.
2929
3030
See [Campbell & Peterson, 2016](https://arxiv.org/abs/1603.00150)
3131
"""
32-
function tight_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number}, σᵣ::Number, σₜ::Number)
32+
function tight_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number}, σᵣ::Number, σₜ::Number, maximize::Bool = false)
3333
# prepare positions and angles
3434
xnorm, ynorm = norm(x), norm(y)
3535
if xnorm*ynorm == 0
@@ -42,13 +42,23 @@ function tight_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number},
4242
# upper bound distance at hypercube center
4343
ubdist = norm(x - y)
4444

45-
# lower bound distance from the nearest point on the "spherical cap"
46-
if cosα >= cosβ
47-
lbdist = max(abs(xnorm-ynorm) - sqrt3*σₜ, 0)
45+
if maximize
46+
# this case is intended for situations where the objective function scales negatively with distance\
47+
# lbdist, which will be the further point on the spherical cap, will be larger than ubdist
48+
if cosα + cosβ >= π
49+
lbdist = xnorm + ynorm + sqrt3*σₜ
50+
else
51+
lbdist = (xnorm^2 + ynorm^2 - 2*xnorm*ynorm*(cosα*cosβ-√((1-cosα^2)*(1-cosβ^2)))) + sqrt3*σₜ
52+
end
4853
else
49-
lbdist = try max((xnorm^2 + ynorm^2 - 2*xnorm*ynorm*(cosα*cosβ+√((1-cosα^2)*(1-cosβ^2)))) - sqrt3*σₜ, 0) # law of cosines
50-
catch e # when the argument for the square root is negative (within machine precision of 0, usually)
51-
0
54+
# lower bound distance from the nearest point on the "spherical cap"
55+
if cosα >= cosβ
56+
lbdist = max(abs(xnorm-ynorm) - sqrt3*σₜ, 0)
57+
else
58+
lbdist = try max((xnorm^2 + ynorm^2 - 2*xnorm*ynorm*(cosα*cosβ+√((1-cosα^2)*(1-cosβ^2)))) - sqrt3*σₜ, 0) # law of cosines
59+
catch e # when the argument for the square root is negative (within machine precision of 0, usually)
60+
0
61+
end
5262
end
5363
end
5464

@@ -57,7 +67,7 @@ function tight_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number},
5767
return (numtype(lbdist), numtype(ubdist))
5868
end
5969

60-
tight_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number}, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number
61-
) = (R.sx^2 + R.sy^2 + R.sz^2) > pisq ? infbounds(x,y) : tight_distance_bounds(R*x, y-T, σᵣ, σₜ) # tight_distance_bounds(R*x, y-T, σᵣ, σₜ)
62-
tight_distance_bounds(x::SVector{3}, y::SVector{3}, block::UncertaintyRegion) = tight_distance_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ)
63-
tight_distance_bounds(x::SVector{3}, y::SVector{3}, block::Union{RotationRegion, TranslationRegion}) = tight_distance_bounds(x, y, UncertaintyRegion(block))
70+
tight_distance_bounds(x::SVector{3,<:Number}, y::SVector{3,<:Number}, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, maximize::Bool = false,
71+
) = (R.sx^2 + R.sy^2 + R.sz^2) > pisq ? infbounds(x,y) : tight_distance_bounds(R*x, y-T, σᵣ, σₜ, maximize) # tight_distance_bounds(R*x, y-T, σᵣ, σₜ)
72+
tight_distance_bounds(x::SVector{3}, y::SVector{3}, block::UncertaintyRegion, maximize::Bool = false) = tight_distance_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, maximize)
73+
tight_distance_bounds(x::SVector{3}, y::SVector{3}, block::Union{RotationRegion, TranslationRegion}, maximize::Bool = false) = tight_distance_bounds(x, y, UncertaintyRegion(block), maximize)

src/draw.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const DEFAULT_COLORS = [ # CUD colors: https://jfly.uni-koeln.de/colo
2121
const cosθs = [cos(θ) for θ in θs]
2222
const sinθs = [sin(θ) for θ in θs]
2323

24-
equal_volume_radius(σ, ϕ) = (EQUAL_VOL_CONST*ϕ)^(1/3) * σ
24+
equal_volume_radius(σ, ϕ) = (EQUAL_VOL_CONST*abs(ϕ))^(1/3) * σ
2525

2626
function flat_circle!(f, pos, r, dim::Int; kwargs...)
2727
if dim == 3
@@ -67,7 +67,7 @@ function plot!(gd::GaussianDisplay{<:NTuple{<:Any, <:AbstractIsotropicGaussian}}
6767
color = gd[:color][]
6868
plotfun = disp == :wire ? wire_sphere! : ( disp == :solid ? solid_sphere! : throw(ArgumentError("Unrecognized display option: `$disp`")))
6969
for g in gauss
70-
plotfun(gd, g.μ, equal_volume_radius(g.σ, g.ϕ); color=color)
70+
plotfun(gd, g.μ, g.σ; color=color)
7171
end
7272
return gd
7373
end

src/gogma/align.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
function gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; kwargs...)
2-
pσ, pϕ = pairwise_consts(gmmx,gmmy)
1+
function gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...)
2+
pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions)
33
boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ)
44
localfun(x,y,block) = local_align(x,y,block,pσ,pϕ)
55
return branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...)
66
end
7-
function rot_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; kwargs...)
8-
pσ, pϕ = pairwise_consts(gmmx,gmmy)
7+
function rot_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...)
8+
pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions)
99
boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ)
1010
localfun(x,y,block) = local_align(x,y,block,pσ,pϕ)
1111
rot_branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...)
1212
end
13-
function trl_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; kwargs...)
14-
pσ, pϕ = pairwise_consts(gmmx,gmmy)
13+
function trl_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...)
14+
pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions)
1515
boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ)
1616
localfun(x,y,block) = local_align(x,y,block,pσ,pϕ)
1717
trl_branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...)

src/gogma/bounds.jl

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ loose_distance_bounds(x::AbstractGaussian, y::AbstractGaussian, args...) = loose
22
tight_distance_bounds(x::AbstractGaussian, y::AbstractGaussian, args...) = tight_distance_bounds(x.μ, y.μ, args...)
33

44
# prepare pairwise values for `σx^2 + σy^2` and `ϕx * ϕy` for all gaussians in `gmmx` and `gmmy`
5-
function pairwise_consts(gmmx::AbstractIsotropicGMM, gmmy::AbstractIsotropicGMM)
5+
function pairwise_consts(gmmx::AbstractIsotropicGMM, gmmy::AbstractIsotropicGMM, interactions=nothing)
66
t = promote_type(numbertype(gmmx),numbertype(gmmy))
77
pσ, pϕ = zeros(t, length(gmmx), length(gmmy)), zeros(t, length(gmmx), length(gmmy))
88
for (i,gaussx) in enumerate(gmmx.gaussians)
@@ -14,13 +14,33 @@ function pairwise_consts(gmmx::AbstractIsotropicGMM, gmmy::AbstractIsotropicGMM)
1414
return pσ, pϕ
1515
end
1616

17-
function pairwise_consts(mgmmx::AbstractMultiGMM{N,T,K}, mgmmy::AbstractMultiGMM{N,S,K}) where {N,T,S,K}
18-
t = promote_type(numbertype(mgmmx),numbertype(mgmmy))
19-
mpσ, mpϕ = Dict{K, Matrix{t}}(), Dict{K, Matrix{t}}()
20-
for key in keys(mgmmx.gmms) keys(mgmmy.gmms)
21-
pσ, pϕ = pairwise_consts(mgmmx.gmms[key], mgmmy.gmms[key])
22-
push!(mpσ, Pair(key, pσ))
23-
push!(mpϕ, Pair(key, pϕ))
17+
function pairwise_consts(mgmmx::AbstractMultiGMM{N,T,K}, mgmmy::AbstractMultiGMM{N,S,K}, interactions::Union{Nothing,Dict{K,Dict{K,V}}}=nothing) where {N,T,S,K,V <: Number}
18+
t = promote_type(numbertype(mgmmx),numbertype(mgmmy), isnothing(interactions) ? numbertype(mgmmx) : V)
19+
xkeys = keys(mgmmx.gmms)
20+
ykeys = keys(mgmmy.gmms)
21+
if isnothing(interactions)
22+
interactions = Dict{K, Dict{K, t}}()
23+
for key in xkeys ykeys
24+
interactions[key] = Dict{K, t}(key => one(t))
25+
end
26+
end
27+
mpσ, mpϕ = Dict{K, Dict{K, Matrix{t}}}(), Dict{K, Dict{K,Matrix{t}}}()
28+
for key1 in keys(interactions)
29+
if key1 xkeys
30+
push!(mpσ, key1 => Dict{K, Matrix{t}}())
31+
push!(mpϕ, key1 => Dict{K, Matrix{t}}())
32+
for key2 in keys(interactions[key1])
33+
if key2 ykeys
34+
pσ, pϕ = pairwise_consts(mgmmx.gmms[key1], mgmmy.gmms[key2])
35+
push!(mpσ[key1], key2 => pσ)
36+
push!(mpϕ[key1], key2 => interactions[key1][key2] .* pϕ)
37+
end
38+
end
39+
if isempty(mpσ[key1])
40+
delete!(mpσ, key1)
41+
delete!(mpϕ, key1)
42+
end
43+
end
2444
end
2545
return mpσ, mpϕ
2646
end
@@ -41,7 +61,7 @@ the uncertainty region is assumed to be centered at the origin (i.e. x has alrea
4161
See [Campbell & Peterson, 2016](https://arxiv.org/abs/1603.00150)
4262
"""
4363
function gauss_l2_bounds(x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian, R::RotationVec, T::SVector{3}, σᵣ, σₜ, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; distance_bound_fun = tight_distance_bounds)
44-
(lbdist, ubdist) = distance_bound_fun(R*x.μ, y.μ-T, σᵣ, σₜ)
64+
(lbdist, ubdist) = distance_bound_fun(R*x.μ, y.μ-T, σᵣ, σₜ, w < 0)
4565

4666
# evaluate objective function at each distance to get upper and lower bounds
4767
return -overlap(lbdist^2, s, w), -overlap(ubdist^2, s, w)
@@ -58,7 +78,7 @@ gauss_l2_bounds(x::AbstractGaussian, y::AbstractGaussian, block::SearchRegion, s
5878

5979

6080

61-
function gauss_l2_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, pσ=nothing, pϕ=nothing; kwargs...)
81+
function gauss_l2_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, pσ=nothing, pϕ=nothing, interactions=nothing; kwargs...)
6282
# prepare pairwise widths and weights, if not provided
6383
if isnothing(pσ) || isnothing(pϕ)
6484
pσ, pϕ = pairwise_consts(gmmx, gmmy)
@@ -75,17 +95,19 @@ function gauss_l2_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::Ro
7595
return lb, ub
7696
end
7797

78-
function gauss_l2_bounds(mgmmx::AbstractMultiGMM, mgmmy::AbstractMultiGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, mpσ=nothing, mpϕ=nothing)
98+
function gauss_l2_bounds(mgmmx::AbstractMultiGMM, mgmmy::AbstractMultiGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, mpσ=nothing, mpϕ=nothing, interactions=nothing)
7999
# prepare pairwise widths and weights, if not provided
80100
if isnothing(mpσ) || isnothing(mpϕ)
81-
mpσ, mpϕ = pairwise_consts(mgmmx, mgmmy)
101+
mpσ, mpϕ = pairwise_consts(mgmmx, mgmmy, interactions)
82102
end
83103

84104
# sum bounds for each pair of points
85105
lb = 0.
86106
ub = 0.
87-
for key in keys(mgmmx.gmms) keys(mgmmy.gmms)
88-
lb, ub = (lb, ub) .+ gauss_l2_bounds(mgmmx.gmms[key], mgmmy.gmms[key], R, T, σᵣ, σₜ, mpσ[key], mpϕ[key])
107+
for (key1, intrs) in mpσ
108+
for (key2, pσ) in intrs
109+
lb, ub = (lb, ub) .+ gauss_l2_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2])
110+
end
89111
end
90112
return lb, ub
91113
end

src/gogma/overlap.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,18 @@ end
6161
6262
Calculates the unnormalized overlap between two `AbstractMultiGMM` objects.
6363
"""
64-
function overlap(x::AbstractMultiGMM, y::AbstractMultiGMM, mpσ=nothing, mpϕ=nothing)
64+
function overlap(x::AbstractMultiGMM, y::AbstractMultiGMM, mpσ=nothing, mpϕ=nothing, interactions=nothing)
6565
# prepare pairwise widths and weights, if not provided
6666
if isnothing(mpσ) || isnothing(mpϕ)
67-
mpσ, mpϕ = pairwise_consts(x, y)
67+
mpσ, mpϕ = pairwise_consts(x, y, interactions)
6868
end
6969

7070
# sum overlaps from each keyed pairs of GMM
7171
ovlp = zero(promote_type(numbertype(x),numbertype(y)))
72-
for k in keys(x.gmms) keys(y.gmms)
73-
ovlp += overlap(x.gmms[k], y.gmms[k], mpσ[k], mpϕ[k])
72+
for k1 in keys(mpσ)
73+
for k2 in keys(mpσ[k1])
74+
ovlp += overlap(x.gmms[k1], y.gmms[k2], mpσ[k1][k2], mpϕ[k1][k2])
75+
end
7476
end
7577
return ovlp
7678
end

src/localalign.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,21 @@
11
tformwithparams(X,x) = RotationVec(X[1:3]...)*x + SVector{3}(X[4:6]...)
2+
# function tformwithparams(X,x)
3+
# if sum(abs2, X[1:3]) == 0 # handled for autodiff around 0
4+
# T = eltype(X)
5+
# θ = norm(X[1:3])
6+
# a = θ > 0 ? X[1] / θ : one(T)
7+
# b = θ > 0 ? X[2] / θ : zero(T)
8+
# c = θ > 0 ? X[3] / θ : zero(T)
9+
# R = AngleAxis(θ, a, b, c)
10+
# @show R
11+
# else
12+
# R = RotationVec(X[1:3]...)
13+
# end
14+
# t = SVector{3}(X[4:6]...)
15+
# @show (R*x)[1]
16+
# return R*x + t
17+
# end
18+
219
overlapobj(X,x,y,args...) = -overlap(tformwithparams(X,x), y, args...)
320

421
function distanceobj(X, x, y; correspondence = hungarian_assignment)
@@ -34,6 +51,10 @@ function local_align(x::AbstractModel, y::AbstractModel, block::SearchRegion, ar
3451

3552
# set initial guess at the center of the block
3653
initial_X = center(block)
54+
# if (typeof(block) <: UncertaintyRegion && sum(abs2, initial_X[1:Int(end/2)]) == 0) || (typeof(block) <: RotationRegion && sum(abs2, initial_X) == 0)
55+
# T = eltype(initial_X)
56+
# initial_X = initial_X .+ [eps(T), zeros(T, length(initial_X)-1)...]
57+
# end
3758

3859
# local optimization within the block
3960
f(X) = alignment_objective(X, x, y, block, args...; kwargs...)

test/runtests.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,40 @@ end
140140
@test isapprox(rocs_align(gmmx, gmmy).minimum, -overlap(gmmx,gmmx); atol=1E-12)
141141
end
142142

143+
@testset "MultiGMMs with interactions" begin
144+
tetrahedral = [
145+
[0.,0.,1.],
146+
[sqrt(8/9), 0., -1/3],
147+
[-sqrt(2/9),sqrt(2/3),-1/3],
148+
[-sqrt(2/9),-sqrt(2/3),-1/3]
149+
]
150+
ch_g = IsotropicGaussian(tetrahedral[1], 1.0, 1.0)
151+
s_gs = [IsotropicGaussian(x, 0.5, 1.0) for (i,x) in enumerate(tetrahedral)]
152+
mgmmx = IsotropicMultiGMM(Dict(
153+
:positive => IsotropicGMM([ch_g]),
154+
:steric => IsotropicGMM(s_gs)
155+
))
156+
mgmmy = IsotropicMultiGMM(Dict(
157+
:negative => IsotropicGMM([ch_g]),
158+
:steric => IsotropicGMM(s_gs)
159+
))
160+
interactions = Dict(
161+
:positive => Dict(
162+
:positive => -1.0,
163+
:negative => 1.0,
164+
),
165+
:negative => Dict(
166+
:positive => 1.0,
167+
:negative => -1.0,
168+
),
169+
:steric => Dict(
170+
:steric => -1.0,
171+
),
172+
)
173+
randtform = AffineMap(RotationVec*0.1rand(3)...), SVector{3}(0.1*rand(3)...))
174+
res = gogma_align(randtform(mgmmx), mgmmy; interactions=interactions, maxsplits=5e3, nextblockfun=GMA.randomblock)
175+
end
176+
143177
@testset "GO-ICP and GO-IH run without errors" begin
144178
xpts = [[0.,0.,0.], [3.,0.,0.,], [0.,4.,0.]]
145179
ypts = [[1.,1.,1.], [1.,-2.,1.], [1.,1.,-3.]]

0 commit comments

Comments
 (0)