From f2a30857dea2a6e934f10dc1b879863e187658f4 Mon Sep 17 00:00:00 2001 From: Tom McGrath Date: Wed, 5 Jun 2024 14:21:15 -0500 Subject: [PATCH] allow user supplied overlap functions --- Project.toml | 2 +- src/GaussianMixtureAlignment.jl | 4 +- src/gogma/align.jl | 28 +++++++------- src/gogma/bounds.jl | 57 ++++++++++++++++++++--------- src/gogma/overlap.jl | 52 +++++++++++--------------- src/localalign.jl | 6 +-- test/runtests.jl | 65 +++++++++++++++++++++++++++------ 7 files changed, 136 insertions(+), 78 deletions(-) diff --git a/Project.toml b/Project.toml index 5bd4bb1..229a12e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GaussianMixtureAlignment" uuid = "f2431ed1-b9c2-4fdb-af1b-a74d6c93b3b3" authors = ["Tom McGrath and contributors"] -version = "0.2.2" +version = "0.2.3" [deps] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" diff --git a/src/GaussianMixtureAlignment.jl b/src/GaussianMixtureAlignment.jl index 3afb523..0893743 100644 --- a/src/GaussianMixtureAlignment.jl +++ b/src/GaussianMixtureAlignment.jl @@ -38,7 +38,9 @@ using Colors export AbstractGaussian, AbstractGMM export IsotropicGaussian, IsotropicGMM, IsotropicMultiGMM -export overlap, force!, gogma_align, rot_gogma_align, trl_gogma_align, tiv_gogma_align +export overlap, generic_overlap, gaussian_overlap, force! +export lowestlbblock, randomblock +export gogma_align, rot_gogma_align, trl_gogma_align, tiv_gogma_align export rocs_align export PointSet, MultiPointSet export kabsch, icp, iterative_hungarian, goicp_align, goih_align, tiv_goicp_align, tiv_goih_align diff --git a/src/gogma/align.jl b/src/gogma/align.jl index 20abd7f..7f158fb 100644 --- a/src/gogma/align.jl +++ b/src/gogma/align.jl @@ -1,28 +1,28 @@ -function gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...) +function gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, objective=gaussian_overlap, kwargs...) pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions) - boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ) - localfun(x,y,block) = local_align(x,y,block,pσ,pϕ) + boundsfun(x,y,block) = generic_bounds(x,y,block,pσ,pϕ; objective=objective) + localfun(x,y,block) = local_align(x,y,block,pσ,pϕ; objective=objective) return branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...) end -function rot_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...) +function rot_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, objective=gaussian_overlap, kwargs...) pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions) - boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ) - localfun(x,y,block) = local_align(x,y,block,pσ,pϕ) + boundsfun(x,y,block) = generic_bounds(x,y,block,pσ,pϕ; objective=objective) + localfun(x,y,block) = local_align(x,y,block,pσ,pϕ; objective=objective) rot_branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...) end -function trl_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...) +function trl_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, objective=gaussian_overlap, kwargs...) pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions) - boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ) - localfun(x,y,block) = local_align(x,y,block,pσ,pϕ) + boundsfun(x,y,block) = generic_bounds(x,y,block,pσ,pϕ; objective=objective) + localfun(x,y,block) = local_align(x,y,block,pσ,pϕ; objective=objective) trl_branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...) end -function tiv_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM, cx=Inf, cy=Inf; kwargs...) +function tiv_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM, cx=Inf, cy=Inf; objective=gaussian_overlap, kwargs...) tivgmmx, tivgmmy = tivgmm(gmmx, cx), tivgmm(gmmy, cy) pσ, pϕ = pairwise_consts(gmmx,gmmy) tivpσ, tivpϕ = pairwise_consts(tivgmmx,tivgmmy) - boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ) - rot_boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,tivpσ,tivpϕ) - localfun(x,y,block) = local_align(x,y,block,pσ,pϕ) - rot_localfun(x,y,block) = local_align(x,y,block,tivpσ,tivpϕ) + boundsfun(x,y,block) = generic_bounds(x,y,block,pσ,pϕ; objective=objective) + rot_boundsfun(x,y,block) = generic_bounds(x,y,block,tivpσ,tivpϕ; objective=objective) + localfun(x,y,block) = local_align(x,y,block,pσ,pϕ; objective=objective) + rot_localfun(x,y,block) = local_align(x,y,block,tivpσ,tivpϕ; objective=objective) tiv_branchbound(gmmx, gmmy, tivgmm(gmmx, cx), tivgmm(gmmy, cy); boundsfun=boundsfun, rot_boundsfun=rot_boundsfun, localfun=localfun, rot_localfun=rot_localfun, kwargs...) end \ No newline at end of file diff --git a/src/gogma/bounds.jl b/src/gogma/bounds.jl index 118a8d2..b6e12fd 100644 --- a/src/gogma/bounds.jl +++ b/src/gogma/bounds.jl @@ -63,10 +63,10 @@ end """ - lowerbound, upperbound = gauss_l2_bounds(x::Union{IsotropicGaussian, AbstractGMM}, y::Union{IsotropicGaussian, AbstractGMM}, σᵣ, σₜ) - lowerbound, upperbound = gauss_l2_bounds(x, y, R::RotationVec, T::SVector{3}, σᵣ, σₜ) + lowerbound, upperbound = generic_bounds(x::Union{IsotropicGaussian, AbstractGMM}, y::Union{IsotropicGaussian, AbstractGMM}, σᵣ, σₜ; objective = overlap) + lowerbound, upperbound = generic_bounds(x, y, R::RotationVec, T::SVector{3}, σᵣ, σₜ; objective = gaussian_overlap) -Finds the bounds for overlap between two isotropic Gaussian distributions, two isotropic GMMs, or `two sets of +Finds the bounds for the specified `objective` function between two isotropic Gaussian distributions, two isotropic GMMs, or two sets of labeled isotropic GMMs for a particular region in 6-dimensional rigid rotation space, defined by `R`, `T`, `σᵣ` and `σₜ`. `R` and `T` represent the rotation and translation, respectively, that are at the center of the uncertainty region. If they are not provided, @@ -74,27 +74,29 @@ the uncertainty region is assumed to be centered at the origin (i.e. x has alrea `σᵣ` and `σₜ` represent the sizes of the rotation and translation uncertainty regions. +The `objective` should be a function that takes the squared distance between the means of two `IsotropicGaussian`s, the sum of their variances, and the product of their amplitudes. + See [Campbell & Peterson, 2016](https://arxiv.org/abs/1603.00150) """ -function gauss_l2_bounds(x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian, R::RotationVec, T::SVector{3}, σᵣ, σₜ, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; distance_bound_fun = tight_distance_bounds) +function generic_bounds(x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian, R::RotationVec, T::SVector{3}, σᵣ, σₜ, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; distance_bound_fun = tight_distance_bounds, objective = gauss_l2_bounds, kwargs...) (lbdist, ubdist) = distance_bound_fun(R*x.μ, y.μ-T, σᵣ, σₜ, w < 0) # evaluate objective function at each distance to get upper and lower bounds - return -overlap(lbdist^2, s, w), -overlap(ubdist^2, s, w) + return -objective(lbdist^2, s, w; kwargs...), -objective(ubdist^2, s, w; kwargs...) end # gauss_l2_bounds(x::AbstractGaussian, y::AbstractGaussian, R::RotationVec, T::SVector{3}, σᵣ, σₜ, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... # ) = gauss_l2_bounds(R*x, y-T, σᵣ, σₜ, tform.translation, s, w; kwargs...) -gauss_l2_bounds(x::AbstractGaussian, y::AbstractGaussian, block::UncertaintyRegion, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... - ) = gauss_l2_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, s, w; kwargs...) +generic_bounds(x::AbstractGaussian, y::AbstractGaussian, block::UncertaintyRegion, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... + ) = generic_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, s, w; kwargs...) -gauss_l2_bounds(x::AbstractGaussian, y::AbstractGaussian, block::SearchRegion, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... - ) = gauss_l2_bounds(x, y, UncertaintyRegion(block), s, w; kwargs...) +generic_bounds(x::AbstractGaussian, y::AbstractGaussian, block::SearchRegion, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... + ) = generic_bounds(x, y, UncertaintyRegion(block), s, w; kwargs...) -function gauss_l2_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, pσ=nothing, pϕ=nothing, interactions=nothing; kwargs...) +function generic_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, pσ=nothing, pϕ=nothing, interactions=nothing; kwargs...) # prepare pairwise widths and weights, if not provided if isnothing(pσ) || isnothing(pϕ) pσ, pϕ = pairwise_consts(gmmx, gmmy) @@ -105,24 +107,28 @@ function gauss_l2_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::Ro ub = 0. for (i,x) in enumerate(gmmx.gaussians) for (j,y) in enumerate(gmmy.gaussians) - lb, ub = (lb, ub) .+ gauss_l2_bounds(x, y, R, T, σᵣ, σₜ, pσ[i,j], pϕ[i,j]; kwargs...) + lb, ub = (lb, ub) .+ generic_bounds(x, y, R, T, σᵣ, σₜ, pσ[i,j], pϕ[i,j]; kwargs...) end end return lb, ub end -function gauss_l2_bounds(mgmmx::AbstractMultiGMM, mgmmy::AbstractMultiGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, mpσ=nothing, mpϕ=nothing, interactions=nothing) +function generic_bounds(mgmmx::AbstractMultiGMM, mgmmy::AbstractMultiGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, mpσ=nothing, mpϕ=nothing, interactions=nothing; objective=gaussian_overlap, kwargs...) # prepare pairwise widths and weights, if not provided if isnothing(mpσ) || isnothing(mpϕ) mpσ, mpϕ = pairwise_consts(mgmmx, mgmmy, interactions) end + # allow for different objective functions for each pair of keys + isdict = isa(objective, Dict) + # sum bounds for each pair of points lb = 0. ub = 0. for (key1, intrs) in mpσ for (key2, pσ) in intrs - lb, ub = (lb, ub) .+ gauss_l2_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]) + obj = !isdict ? objective : (haskey(objective, (key1,key2)) ? objective[(key1,key2)] : objective[(key2,key1)]) + lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective = obj, kwargs...) end end return lb, ub @@ -131,8 +137,25 @@ end # gauss_l2_bounds(x::AbstractGMM, y::AbstractGMM, R::RotationVec, T::SVector{3}, args...; kwargs... # ) = gauss_l2_bounds(R*x, y-T, args...; kwargs...) -gauss_l2_bounds(x::AbstractGMM, y::AbstractGMM, block::UncertaintyRegion, args...; kwargs... - ) = gauss_l2_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, args...; kwargs...) +generic_bounds(x::AbstractGMM, y::AbstractGMM, block::UncertaintyRegion, args...; kwargs... + ) = generic_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, args...; kwargs...) + +generic_bounds(x::AbstractGMM, y::AbstractGMM, block::SearchRegion, args...; kwargs... + ) = generic_bounds(x, y, UncertaintyRegion(block), args...; kwargs...) + + +""" + lowerbound, upperbound = gauss_l2_bounds(x::Union{IsotropicGaussian, AbstractGMM}, y::Union{IsotropicGaussian, AbstractGMM}, σᵣ, σₜ) + lowerbound, upperbound = gauss_l2_bounds(x, y, R::RotationVec, T::SVector{3}, σᵣ, σₜ) -gauss_l2_bounds(x::AbstractGMM, y::AbstractGMM, block::SearchRegion, args...; kwargs... - ) = gauss_l2_bounds(x, y, UncertaintyRegion(block), args...; kwargs...) \ No newline at end of file +Finds the bounds for overlap between two isotropic Gaussian distributions, two isotropic GMMs, or two sets of +labeled isotropic GMMs for a particular region in 6-dimensional rigid rotation space, defined by `R`, `T`, `σᵣ` and `σₜ`. + +`R` and `T` represent the rotation and translation, respectively, that are at the center of the uncertainty region. If they are not provided, +the uncertainty region is assumed to be centered at the origin (i.e. x has already been transformed). + +`σᵣ` and `σₜ` represent the sizes of the rotation and translation uncertainty regions. + +See [Campbell & Peterson, 2016](https://arxiv.org/abs/1603.00150) +""" +gauss_l2_bounds(args...; kwargs...) = generic_bounds(args...; objective = gaussian_overlap, kwargs...) \ No newline at end of file diff --git a/src/gogma/overlap.jl b/src/gogma/overlap.jl index d078a90..09e93f6 100644 --- a/src/gogma/overlap.jl +++ b/src/gogma/overlap.jl @@ -4,37 +4,21 @@ Calculates the unnormalized overlap between two Gaussian distributions with width `s`, weight `w', and squared distance `distsq`. """ -function overlap(distsq::Real, s::Real, w::Real) +function gaussian_overlap(distsq::Real, s::Real, w::Real) return w * exp(-distsq / (2*s)) # / (sqrt2pi * sqrt(s))^ndims # Note, the normalization term for the Gaussians is left out, since it is not required that the total "volume" of each Gaussian # is equal to 1 (e.g. satisfying the requirements for a probability distribution) end -""" - ovlp = overlap(dist, σx, σy, ϕx, ϕy) - -Calculates the unnormalized overlap between two Gaussian distributions with variances -`σx` and `σy`, weights `ϕx` and `ϕy`, and means separated by distance `dist`. -""" -function overlap(dist::Real, σx::Real, σy::Real, ϕx::Real, ϕy::Real) - return overlap(dist^2, σx^2 + σy^2, ϕx*ϕy) +function generic_overlap(dist::Real, σx::Real, σy::Real, ϕx::Real, ϕy::Real; objective=gaussian_overlap, kwargs...) + return objective(dist^2, σx^2 + σy^2, ϕx*ϕy; kwargs...) end -""" - ovlp = overlap(x::IsotropicGaussian, y::IsotropicGaussian) - -Calculates the unnormalized overlap between two `IsotropicGaussian` objects. -""" -function overlap(x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian, s=x.σ^2+y.σ^2, w=x.ϕ*y.ϕ) - return overlap(sum(abs2, x.μ.-y.μ), s, w) +function generic_overlap(x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian, s=x.σ^2+y.σ^2, w=x.ϕ*y.ϕ; objective=gaussian_overlap, kwargs...) + return objective(sum(abs2, x.μ.-y.μ), s, w; kwargs...) end -""" - ovlp = overlap(x::AbstractSingleGMM, y::AbstractSingleGMM) - -Calculates the unnormalized overlap between two `AbstractSingleGMM` objects. -""" -function overlap(x::AbstractSingleGMM, y::AbstractSingleGMM, pσ=nothing, pϕ=nothing) +function generic_overlap(x::AbstractSingleGMM, y::AbstractSingleGMM, pσ=nothing, pϕ=nothing; kwargs...) # prepare pairwise widths and weights, if not provided if isnothing(pσ) && isnothing(pϕ) pσ, pϕ = pairwise_consts(x, y) @@ -44,33 +28,39 @@ function overlap(x::AbstractSingleGMM, y::AbstractSingleGMM, pσ=nothing, pϕ=no ovlp = zero(promote_type(numbertype(x),numbertype(y))) for (i,gx) in enumerate(x.gaussians) for (j,gy) in enumerate(y.gaussians) - ovlp += overlap(gx, gy, pσ[i,j], pϕ[i,j]) + ovlp += generic_overlap(gx, gy, pσ[i,j], pϕ[i,j]; kwargs...) end end return ovlp end -""" - ovlp = overlap(x::AbstractMultiGMM, y::AbstractMultiGMM) - -Calculates the unnormalized overlap between two `AbstractMultiGMM` objects. -""" -function overlap(x::AbstractMultiGMM, y::AbstractMultiGMM, mpσ=nothing, mpϕ=nothing, interactions=nothing) +function generic_overlap(x::AbstractMultiGMM, y::AbstractMultiGMM, mpσ=nothing, mpϕ=nothing, interactions=nothing; objective=gaussian_overlap, kwargs...) # prepare pairwise widths and weights, if not provided if isnothing(mpσ) && isnothing(mpϕ) mpσ, mpϕ = pairwise_consts(x, y, interactions) end + + isdict = isa(objective, Dict) # sum overlaps from each keyed pairs of GMM ovlp = zero(promote_type(numbertype(x),numbertype(y))) for k1 in keys(mpσ) for k2 in keys(mpσ[k1]) - ovlp += overlap(x.gmms[k1], y.gmms[k2], mpσ[k1][k2], mpϕ[k1][k2]) + obj = !isdict ? objective : (haskey(objective, (k1,k2)) ? objective[(k1,k2)] : objective[(k2,k1)]) + ovlp += generic_overlap(x.gmms[k1], y.gmms[k2], mpσ[k1][k2], mpϕ[k1][k2]; objective = obj, kwargs...) end end return ovlp end +""" + ovlp = overlap(x::G, y::G) where G<:Union{AbstractGaussian, AbstractGMM} + +Calculates the unnormalized overlap between two Gaussians or GMMs. +""" +overlap(args...; kwargs...) = generic_overlap(args...; objective=gaussian_overlap, kwargs...) + + """ l2dist = distance(x, y) @@ -94,7 +84,7 @@ end function force!(f::AbstractVector, x::AbstractVector, y::AbstractVector, s::Real, w::Real) Δ = y - x - f .+= Δ / s * overlap(sum(abs2, Δ), s, w) + f .+= Δ / s * gaussian_overlap(sum(abs2, Δ), s, w) end function force!(f::AbstractVector, x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian, diff --git a/src/localalign.jl b/src/localalign.jl index ca66c97..a29d7ea 100644 --- a/src/localalign.jl +++ b/src/localalign.jl @@ -16,15 +16,15 @@ tformwithparams(X,x) = RotationVec(X[1:3]...)*x + SVector{3}(X[4:6]...) # return R*x + t # end -overlapobj(X,x,y,args...) = -overlap(tformwithparams(X,x), y, args...) +overlapobj(X,x,y,args...; kwargs...) = -generic_overlap(tformwithparams(X,x), y, args...; kwargs...) function distanceobj(X, x, y; correspondence = hungarian_assignment) tformedx = tformwithparams(X,x) return squared_deviation(tformedx, y, correspondence(tformedx,y)) end -function alignment_objective(X, x::AbstractModel, y::AbstractModel, args...; objfun=overlapobj) - return objfun(X,x,y,args...) +function alignment_objective(X, x::AbstractModel, y::AbstractModel, args...; objfun=overlapobj, kwargs...) + return objfun(X,x,y,args...; kwargs...) end # alignment objective for a rigid transformation diff --git a/test/runtests.jl b/test/runtests.jl index a935049..098df38 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,29 +43,29 @@ const GMA = GaussianMixtureAlignment # rotation distances, no translation # anti-aligned (no rotation) and aligned (180 degree rotation) lb, ub = gauss_l2_bounds(x,y,RotationRegion(0.)) - @test lb ≈ -GMA.overlap(7^2,2*σ^2,ϕ*ϕ) atol=1e-16 - @test ub ≈ -GMA.overlap(7^2,2*σ^2,ϕ*ϕ) + @test lb ≈ -gaussian_overlap(7^2,2*σ^2,ϕ*ϕ) atol=1e-16 + @test ub ≈ -gaussian_overlap(7^2,2*σ^2,ϕ*ϕ) lb, ub = gauss_l2_bounds(x,y,RotationRegion(RotationVec(0.,0.,π),SVector{3}(0.,0.,0.),0.)) - @test lb ≈ ub ≈ -GMA.overlap(1,2*σ^2,ϕ*ϕ) + @test lb ≈ ub ≈ -gaussian_overlap(1,2*σ^2,ϕ*ϕ) # region with closest alignment at 90 degree rotation lb = gauss_l2_bounds(x,y,RotationRegion(π/2/sqrt3))[1] - @test lb ≈ -GMA.overlap(5^2,2*σ^2,ϕ*ϕ) + @test lb ≈ -gaussian_overlap(5^2,2*σ^2,ϕ*ϕ) lb = gauss_l2_bounds(x,y,RotationRegion(RotationVec(0,0,π/4),SVector{3}(0.,0.,0.),π/4/(sqrt3)))[1] - @test lb ≈ -GMA.overlap(5^2,2*σ^2,ϕ*ϕ) + @test lb ≈ -gaussian_overlap(5^2,2*σ^2,ϕ*ϕ) # translation distance, no rotation # translation region centered at origin lb, ub = gauss_l2_bounds(x,y,TranslationRegion(1/sqrt3)) - @test lb ≈ -GMA.overlap(6^2,2*σ^2,ϕ*ϕ) - @test ub ≈ -GMA.overlap(7^2,2*σ^2,ϕ*ϕ) + @test lb ≈ -gaussian_overlap(6^2,2*σ^2,ϕ*ϕ) + @test ub ≈ -gaussian_overlap(7^2,2*σ^2,ϕ*ϕ) # centered with translation of 1 in +x lb, ub = gauss_l2_bounds(x+SVector(1,0,0),y,TranslationRegion(1/sqrt3)) - @test lb ≈ -GMA.overlap(7^2,2*σ^2,ϕ*ϕ) - @test ub ≈ -GMA.overlap(8^2,2*σ^2,ϕ*ϕ) + @test lb ≈ -gaussian_overlap(7^2,2*σ^2,ϕ*ϕ) + @test ub ≈ -gaussian_overlap(8^2,2*σ^2,ϕ*ϕ) # centered with translation of 3 in +y lb, ub = gauss_l2_bounds(x+SVector(0,3,0),y,TranslationRegion(1/sqrt3)) - @test lb ≈ -GMA.overlap((√(58)-1)^2,2*σ^2,ϕ*ϕ) - @test ub ≈ -GMA.overlap(58,2*σ^2,ϕ*ϕ) + @test lb ≈ -gaussian_overlap((√(58)-1)^2,2*σ^2,ϕ*ϕ) + @test ub ≈ -gaussian_overlap(58,2*σ^2,ϕ*ϕ) end @@ -232,6 +232,49 @@ end res = gogma_align(randtform(mgmmx), mgmmy; interactions=interactions, maxsplits=5e3, nextblockfun=GMA.randomblock) end +@testset "User-supplied objective" begin + tetrahedral = [ + [0.,0.,1.], + [sqrt(8/9), 0., -1/3], + [-sqrt(2/9),sqrt(2/3),-1/3], + [-sqrt(2/9),-sqrt(2/3),-1/3] + ] + ch_g = IsotropicGaussian(tetrahedral[1], 1.0, 1.0) + s_gs = [IsotropicGaussian(x, 0.5, 1.0) for (i,x) in enumerate(tetrahedral)] + mgmmx = IsotropicMultiGMM(Dict( + :positive => IsotropicGMM([ch_g]), + :steric => IsotropicGMM(s_gs) + )) + mgmmy = IsotropicMultiGMM(Dict( + :negative => IsotropicGMM([ch_g]), + :steric => IsotropicGMM(s_gs) + )) + interactions = Dict( + (:positive, :negative) => 1.0, + (:positive, :positive) => -1.0, + (:negative, :negative) => -1.0, + (:steric, :steric) => -1.0, + ) + randtform = AffineMap(RotationVec(π*0.1rand(3)...), SVector{3}(0.1*rand(3)...)) + + # allowing some fuzziness in the distance + relaxed_overlap(distsq, s, w) = gaussian_overlap(max(0, distsq - sign(w) * 0.5), s, w) + res1 = gogma_align(randtform(mgmmx), mgmmy; interactions=interactions, maxsplits=5e3, nextblockfun=GMA.randomblock) + res2 = gogma_align(randtform(mgmmx), mgmmy; interactions=interactions, maxsplits=5e3, nextblockfun=GMA.randomblock, objective=relaxed_overlap) + @test res1.upperbound > res2.upperbound + @test res1.upperbound > -generic_overlap(res1.tform(randtform(mgmmx)), mgmmy, nothing, nothing, interactions; objective=relaxed_overlap) + + # using different objective functions for different interactions + objective = Dict( + (:positive, :negative) => gaussian_overlap, + (:positive, :positive) => gaussian_overlap, + (:negative, :negative) => gaussian_overlap, + (:steric, :steric) => relaxed_overlap, + ) + res3 = gogma_align(randtform(mgmmx), mgmmy; interactions=interactions, maxsplits=5e3, nextblockfun=GMA.randomblock, objective=objective) + @test res1.upperbound > res3.upperbound > res2.upperbound +end + @testset "Forces" begin μx = randn(SVector{3,Float64}) μy = randn(SVector{3,Float64})