-
Notifications
You must be signed in to change notification settings - Fork 0
Allow user supplied overlap functions #48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| name = "GaussianMixtureAlignment" | ||
| uuid = "f2431ed1-b9c2-4fdb-af1b-a74d6c93b3b3" | ||
| authors = ["Tom McGrath <[email protected]> and contributors"] | ||
| version = "0.2.2" | ||
| version = "0.2.3" | ||
|
|
||
| [deps] | ||
| Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,28 +1,28 @@ | ||
| function gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...) | ||
| function gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, objective=gaussian_overlap, kwargs...) | ||
| pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions) | ||
| boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ) | ||
| localfun(x,y,block) = local_align(x,y,block,pσ,pϕ) | ||
| boundsfun(x,y,block) = generic_bounds(x,y,block,pσ,pϕ; objective=objective) | ||
| localfun(x,y,block) = local_align(x,y,block,pσ,pϕ; objective=objective) | ||
| return branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...) | ||
| end | ||
| function rot_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...) | ||
| function rot_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, objective=gaussian_overlap, kwargs...) | ||
| pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions) | ||
| boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ) | ||
| localfun(x,y,block) = local_align(x,y,block,pσ,pϕ) | ||
| boundsfun(x,y,block) = generic_bounds(x,y,block,pσ,pϕ; objective=objective) | ||
| localfun(x,y,block) = local_align(x,y,block,pσ,pϕ; objective=objective) | ||
| rot_branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...) | ||
| end | ||
| function trl_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, kwargs...) | ||
| function trl_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM; interactions=nothing, objective=gaussian_overlap, kwargs...) | ||
| pσ, pϕ = pairwise_consts(gmmx,gmmy,interactions) | ||
| boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ) | ||
| localfun(x,y,block) = local_align(x,y,block,pσ,pϕ) | ||
| boundsfun(x,y,block) = generic_bounds(x,y,block,pσ,pϕ; objective=objective) | ||
| localfun(x,y,block) = local_align(x,y,block,pσ,pϕ; objective=objective) | ||
| trl_branchbound(gmmx, gmmy; boundsfun=boundsfun, localfun=localfun, kwargs...) | ||
| end | ||
| function tiv_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM, cx=Inf, cy=Inf; kwargs...) | ||
| function tiv_gogma_align(gmmx::AbstractGMM, gmmy::AbstractGMM, cx=Inf, cy=Inf; objective=gaussian_overlap, kwargs...) | ||
| tivgmmx, tivgmmy = tivgmm(gmmx, cx), tivgmm(gmmy, cy) | ||
| pσ, pϕ = pairwise_consts(gmmx,gmmy) | ||
| tivpσ, tivpϕ = pairwise_consts(tivgmmx,tivgmmy) | ||
| boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,pσ,pϕ) | ||
| rot_boundsfun(x,y,block) = gauss_l2_bounds(x,y,block,tivpσ,tivpϕ) | ||
| localfun(x,y,block) = local_align(x,y,block,pσ,pϕ) | ||
| rot_localfun(x,y,block) = local_align(x,y,block,tivpσ,tivpϕ) | ||
| boundsfun(x,y,block) = generic_bounds(x,y,block,pσ,pϕ; objective=objective) | ||
| rot_boundsfun(x,y,block) = generic_bounds(x,y,block,tivpσ,tivpϕ; objective=objective) | ||
| localfun(x,y,block) = local_align(x,y,block,pσ,pϕ; objective=objective) | ||
| rot_localfun(x,y,block) = local_align(x,y,block,tivpσ,tivpϕ; objective=objective) | ||
| tiv_branchbound(gmmx, gmmy, tivgmm(gmmx, cx), tivgmm(gmmy, cy); boundsfun=boundsfun, rot_boundsfun=rot_boundsfun, localfun=localfun, rot_localfun=rot_localfun, kwargs...) | ||
| end |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -63,38 +63,40 @@ end | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| lowerbound, upperbound = gauss_l2_bounds(x::Union{IsotropicGaussian, AbstractGMM}, y::Union{IsotropicGaussian, AbstractGMM}, σᵣ, σₜ) | ||||||||||||||||||||||||||
| lowerbound, upperbound = gauss_l2_bounds(x, y, R::RotationVec, T::SVector{3}, σᵣ, σₜ) | ||||||||||||||||||||||||||
| lowerbound, upperbound = generic_bounds(x::Union{IsotropicGaussian, AbstractGMM}, y::Union{IsotropicGaussian, AbstractGMM}, σᵣ, σₜ; objective = overlap) | ||||||||||||||||||||||||||
| lowerbound, upperbound = generic_bounds(x, y, R::RotationVec, T::SVector{3}, σᵣ, σₜ; objective = gaussian_overlap) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Finds the bounds for overlap between two isotropic Gaussian distributions, two isotropic GMMs, or `two sets of | ||||||||||||||||||||||||||
| Finds the bounds for the specified `objective` function between two isotropic Gaussian distributions, two isotropic GMMs, or two sets of | ||||||||||||||||||||||||||
| labeled isotropic GMMs for a particular region in 6-dimensional rigid rotation space, defined by `R`, `T`, `σᵣ` and `σₜ`. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| `R` and `T` represent the rotation and translation, respectively, that are at the center of the uncertainty region. If they are not provided, | ||||||||||||||||||||||||||
| the uncertainty region is assumed to be centered at the origin (i.e. x has already been transformed). | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| `σᵣ` and `σₜ` represent the sizes of the rotation and translation uncertainty regions. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| The `objective` should be a function that takes the squared distance between the means of two `IsotropicGaussian`s, the sum of their variances, and the product of their amplitudes. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| See [Campbell & Peterson, 2016](https://arxiv.org/abs/1603.00150) | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| function gauss_l2_bounds(x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian, R::RotationVec, T::SVector{3}, σᵣ, σₜ, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; distance_bound_fun = tight_distance_bounds) | ||||||||||||||||||||||||||
| function generic_bounds(x::AbstractIsotropicGaussian, y::AbstractIsotropicGaussian, R::RotationVec, T::SVector{3}, σᵣ, σₜ, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; distance_bound_fun = tight_distance_bounds, objective = gauss_l2_bounds, kwargs...) | ||||||||||||||||||||||||||
| (lbdist, ubdist) = distance_bound_fun(R*x.μ, y.μ-T, σᵣ, σₜ, w < 0) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # evaluate objective function at each distance to get upper and lower bounds | ||||||||||||||||||||||||||
| return -overlap(lbdist^2, s, w), -overlap(ubdist^2, s, w) | ||||||||||||||||||||||||||
| return -objective(lbdist^2, s, w; kwargs...), -objective(ubdist^2, s, w; kwargs...) | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # gauss_l2_bounds(x::AbstractGaussian, y::AbstractGaussian, R::RotationVec, T::SVector{3}, σᵣ, σₜ, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... | ||||||||||||||||||||||||||
| # ) = gauss_l2_bounds(R*x, y-T, σᵣ, σₜ, tform.translation, s, w; kwargs...) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| gauss_l2_bounds(x::AbstractGaussian, y::AbstractGaussian, block::UncertaintyRegion, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... | ||||||||||||||||||||||||||
| ) = gauss_l2_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, s, w; kwargs...) | ||||||||||||||||||||||||||
| generic_bounds(x::AbstractGaussian, y::AbstractGaussian, block::UncertaintyRegion, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... | ||||||||||||||||||||||||||
| ) = generic_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, s, w; kwargs...) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| gauss_l2_bounds(x::AbstractGaussian, y::AbstractGaussian, block::SearchRegion, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... | ||||||||||||||||||||||||||
| ) = gauss_l2_bounds(x, y, UncertaintyRegion(block), s, w; kwargs...) | ||||||||||||||||||||||||||
| generic_bounds(x::AbstractGaussian, y::AbstractGaussian, block::SearchRegion, s=x.σ^2 + y.σ^2, w=x.ϕ*y.ϕ; kwargs... | ||||||||||||||||||||||||||
| ) = generic_bounds(x, y, UncertaintyRegion(block), s, w; kwargs...) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| function gauss_l2_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, pσ=nothing, pϕ=nothing, interactions=nothing; kwargs...) | ||||||||||||||||||||||||||
| function generic_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, pσ=nothing, pϕ=nothing, interactions=nothing; kwargs...) | ||||||||||||||||||||||||||
| # prepare pairwise widths and weights, if not provided | ||||||||||||||||||||||||||
| if isnothing(pσ) || isnothing(pϕ) | ||||||||||||||||||||||||||
| pσ, pϕ = pairwise_consts(gmmx, gmmy) | ||||||||||||||||||||||||||
|
|
@@ -105,24 +107,28 @@ function gauss_l2_bounds(gmmx::AbstractSingleGMM, gmmy::AbstractSingleGMM, R::Ro | |||||||||||||||||||||||||
| ub = 0. | ||||||||||||||||||||||||||
| for (i,x) in enumerate(gmmx.gaussians) | ||||||||||||||||||||||||||
| for (j,y) in enumerate(gmmy.gaussians) | ||||||||||||||||||||||||||
| lb, ub = (lb, ub) .+ gauss_l2_bounds(x, y, R, T, σᵣ, σₜ, pσ[i,j], pϕ[i,j]; kwargs...) | ||||||||||||||||||||||||||
| lb, ub = (lb, ub) .+ generic_bounds(x, y, R, T, σᵣ, σₜ, pσ[i,j], pϕ[i,j]; kwargs...) | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| return lb, ub | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| function gauss_l2_bounds(mgmmx::AbstractMultiGMM, mgmmy::AbstractMultiGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, mpσ=nothing, mpϕ=nothing, interactions=nothing) | ||||||||||||||||||||||||||
| function generic_bounds(mgmmx::AbstractMultiGMM, mgmmy::AbstractMultiGMM, R::RotationVec, T::SVector{3}, σᵣ::Number, σₜ::Number, mpσ=nothing, mpϕ=nothing, interactions=nothing; objective=gaussian_overlap, kwargs...) | ||||||||||||||||||||||||||
| # prepare pairwise widths and weights, if not provided | ||||||||||||||||||||||||||
| if isnothing(mpσ) || isnothing(mpϕ) | ||||||||||||||||||||||||||
| mpσ, mpϕ = pairwise_consts(mgmmx, mgmmy, interactions) | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # allow for different objective functions for each pair of keys | ||||||||||||||||||||||||||
| isdict = isa(objective, Dict) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # sum bounds for each pair of points | ||||||||||||||||||||||||||
| lb = 0. | ||||||||||||||||||||||||||
| ub = 0. | ||||||||||||||||||||||||||
| for (key1, intrs) in mpσ | ||||||||||||||||||||||||||
| for (key2, pσ) in intrs | ||||||||||||||||||||||||||
| lb, ub = (lb, ub) .+ gauss_l2_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]) | ||||||||||||||||||||||||||
| obj = !isdict ? objective : (haskey(objective, (key1,key2)) ? objective[(key1,key2)] : objective[(key2,key1)]) | ||||||||||||||||||||||||||
| lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective = obj, kwargs...) | ||||||||||||||||||||||||||
|
Comment on lines
+130
to
+131
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If profiling reveals this line to be a bottleneck, here's one way that might help a tiny bit:
Suggested change
The important part of this is the The other part of the change is vastly less important, but exploits the fact that haskey(dict, a) ? dict[a] : nothinginvolves looking up the key get(dict, a, nothing)only looks up the key struct NotFound end # a private type for internal use only
const notfound = NotFound()
x = get(dict, key, notfound)
if x !== notfound
...Then there's no way that You probably have enough places in your code that might check both orders of the keys that it might be worth splitting the double- |
||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| return lb, ub | ||||||||||||||||||||||||||
|
|
@@ -131,8 +137,25 @@ end | |||||||||||||||||||||||||
| # gauss_l2_bounds(x::AbstractGMM, y::AbstractGMM, R::RotationVec, T::SVector{3}, args...; kwargs... | ||||||||||||||||||||||||||
| # ) = gauss_l2_bounds(R*x, y-T, args...; kwargs...) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| gauss_l2_bounds(x::AbstractGMM, y::AbstractGMM, block::UncertaintyRegion, args...; kwargs... | ||||||||||||||||||||||||||
| ) = gauss_l2_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, args...; kwargs...) | ||||||||||||||||||||||||||
| generic_bounds(x::AbstractGMM, y::AbstractGMM, block::UncertaintyRegion, args...; kwargs... | ||||||||||||||||||||||||||
| ) = generic_bounds(x, y, block.R, block.T, block.σᵣ, block.σₜ, args...; kwargs...) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| generic_bounds(x::AbstractGMM, y::AbstractGMM, block::SearchRegion, args...; kwargs... | ||||||||||||||||||||||||||
| ) = generic_bounds(x, y, UncertaintyRegion(block), args...; kwargs...) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| lowerbound, upperbound = gauss_l2_bounds(x::Union{IsotropicGaussian, AbstractGMM}, y::Union{IsotropicGaussian, AbstractGMM}, σᵣ, σₜ) | ||||||||||||||||||||||||||
| lowerbound, upperbound = gauss_l2_bounds(x, y, R::RotationVec, T::SVector{3}, σᵣ, σₜ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| gauss_l2_bounds(x::AbstractGMM, y::AbstractGMM, block::SearchRegion, args...; kwargs... | ||||||||||||||||||||||||||
| ) = gauss_l2_bounds(x, y, UncertaintyRegion(block), args...; kwargs...) | ||||||||||||||||||||||||||
| Finds the bounds for overlap between two isotropic Gaussian distributions, two isotropic GMMs, or two sets of | ||||||||||||||||||||||||||
| labeled isotropic GMMs for a particular region in 6-dimensional rigid rotation space, defined by `R`, `T`, `σᵣ` and `σₜ`. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| `R` and `T` represent the rotation and translation, respectively, that are at the center of the uncertainty region. If they are not provided, | ||||||||||||||||||||||||||
| the uncertainty region is assumed to be centered at the origin (i.e. x has already been transformed). | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| `σᵣ` and `σₜ` represent the sizes of the rotation and translation uncertainty regions. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| See [Campbell & Peterson, 2016](https://arxiv.org/abs/1603.00150) | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| gauss_l2_bounds(args...; kwargs...) = generic_bounds(args...; objective = gaussian_overlap, kwargs...) | ||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should spell out the exact arguments. From this I infer that the argument order is
objective(Δμ, σ²sum, ϕprod), but best to be explicit.Also, does
objectiveneed to satisfy certain requirements? E.g., does it have to be monotonic? (Lennard-Jones comes to mind.) Do you need to specialize any methods for your objective function? E.g.,estimate_lower_bound(::typeof(lennardjones), Δμ, σ²sum, ϕprod). If there is an API that the user-supplied function needs to satisfy, it should be spelled out.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right that Lennard-Jones wouldn't work here without some careful thought -- maybe handling the attractive and repulsive terms separately.
The assumption is that the objective needs to be monotonically decreasing with increasing Δμ² (which is absolutely important to be clear about).