Skip to content

Commit 5298908

Browse files
authored
Compute the "force" between two Gaussians (#43)
This is the gradient of `overlap` with respect to the centroid of the "probe" Gaussian.
1 parent b7e3722 commit 5298908

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
2727
GaussianMixtureAlignmentMakieExt = "Makie"
2828

2929
[compat]
30+
Colors = "0.12"
3031
CoordinateTransformations = "0.6"
3132
Distances = "0.10"
32-
Colors = "0.12"
33+
ForwardDiff = "0.10"
3334
GenericLinearAlgebra = "0.3"
3435
GeometryBasics = "0.4"
3536
Hungarian = "0.7"
@@ -45,8 +46,9 @@ StaticArrays = "1.5"
4546
julia = "1.10"
4647

4748
[extras]
49+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4850
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
4951
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5052

5153
[targets]
52-
test = ["Test", "IntervalSets"]
54+
test = ["Test", "ForwardDiff", "IntervalSets"]

src/GaussianMixtureAlignment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ using Colors
3838

3939
export AbstractGaussian, AbstractGMM
4040
export IsotropicGaussian, IsotropicGMM, IsotropicMultiGMM
41-
export overlap, gogma_align, rot_gogma_align, trl_gogma_align, tiv_gogma_align
41+
export overlap, force!, gogma_align, rot_gogma_align, trl_gogma_align, tiv_gogma_align
4242
export rocs_align
4343
export PointSet, MultiPointSet
4444
export kabsch, icp, iterative_hungarian, goicp_align, goih_align, tiv_goicp_align, tiv_goih_align

src/gogma/overlap.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,28 @@ Calculates the tanimoto distance based on Gaussian overlap between two GMMs.
8888
function tanimoto(x::AbstractGMM, y::AbstractGMM)
8989
o = overlap(x,y)
9090
return o / (overlap(x,x) + overlap(y,y) - o)
91-
end
91+
end
92+
93+
## Forces
94+
95+
function force!(f::AbstractVector, x::AbstractVector, y::AbstractVector, s::Real, w::Real)
96+
Δ = y - x
97+
f .+= Δ / s * overlap(sum(abs2, Δ), s, w)
98+
end
99+
100+
function force!(f::AbstractVector, x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian,
101+
s=x.σ^2+y.σ^2, w=x.ϕ*y.ϕ; coef=1)
102+
return force!(f, x.μ, y.μ, s, coef*w)
103+
end
104+
105+
function force!(f::AbstractVector, x::AbstractIsotropicGaussian, y::AbstractIsotropicGMM; kwargs...)
106+
for gy in y.gaussians
107+
force!(f, x, gy; kwargs...)
108+
end
109+
end
110+
111+
function force!(f::AbstractVector, x::AbstractIsotropicGMM, y::AbstractIsotropicGMM; kwargs...)
112+
for gx in x.gaussians
113+
force!(f, gx, y; kwargs...)
114+
end
115+
end

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using LinearAlgebra
55
using StaticArrays
66
using Rotations
77
using CoordinateTransformations
8+
using ForwardDiff
89

910
using GaussianMixtureAlignment: UncertaintyRegion, RotationRegion, TranslationRegion
1011
using GaussianMixtureAlignment: tight_distance_bounds, loose_distance_bounds, gauss_l2_bounds, subranges, sqrt3, UncertaintyRegion, subregions, branchbound, rocs_align, overlap, gogma_align, tiv_gogma_align, tiv_goih_align, overlapobj
@@ -174,6 +175,21 @@ end
174175
res = gogma_align(randtform(mgmmx), mgmmy; interactions=interactions, maxsplits=5e3, nextblockfun=GMA.randomblock)
175176
end
176177

178+
@testset "Forces" begin
179+
μx = randn(SVector{3,Float64})
180+
μy = randn(SVector{3,Float64})
181+
σx = 1 + rand()
182+
σy = 1 + rand()
183+
ϕx = 1 + rand()
184+
ϕy = 1 + rand()
185+
x = IsotropicGaussian(μx, σx, ϕx)
186+
y = IsotropicGaussian(μy, σy, ϕy)
187+
f = zeros(3)
188+
force!(f, x, y)
189+
ovlp(μ) = overlap(IsotropicGaussian(μ, σx, ϕx), y)
190+
@test f ForwardDiff.gradient(ovlp, μx)
191+
end
192+
177193
@testset "GO-ICP and GO-IH run without errors" begin
178194
xpts = [[0.,0.,0.], [3.,0.,0.,], [0.,4.,0.]]
179195
ypts = [[1.,1.,1.], [1.,-2.,1.], [1.,1.,-3.]]

0 commit comments

Comments
 (0)