Skip to content

Commit 21ec0b8

Browse files
timholytmcgrath325
andauthored
[RFC] Fix piracy (#51)
* Update to CoordinateTransformations.kabsch `kabsch` has now been implemented in CoordinateTransformations. While new features are ordinarily not breaking changes, for this package it is due to name conflict. This PR resolves the name conflict, and bumps `[compat]`. Closes #49 * Eliminate piracy of CoordinateTransformations This package extends several functions from CoordinateTransformations, but it extends them in ways that do not involve types defined in this package. Hence these methods are piracy. `trans(A::AbstractMatrix)` could conceivably be defined in CoordinateTransformations, so rather than merge this as-is it might be better to consider submitting it there. * Add Aqua tests I've set these up to run only on CI, but if you prefer it's easier to always run them. They just take a while. * Eliminate piracy of CoordinateTransformations This package extends several functions from CoordinateTransformations, but it extends them in ways that do not involve types defined in this package. Hence these methods are piracy. `trans(A::AbstractMatrix)` could conceivably be defined in CoordinateTransformations, so rather than merge this as-is it might be better to consider submitting it there. * Add Aqua tests I've set these up to run only on CI, but if you prefer it's easier to always run them. They just take a while. * Pass Aqua tests * Fix premature version bump * Fix LinearAlgebra compat entry * Switch to build_tform for local ICP methods --------- Co-authored-by: Tom McGrath <[email protected]>
1 parent dda5e86 commit 21ec0b8

File tree

12 files changed

+99
-59
lines changed

12 files changed

+99
-59
lines changed

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GaussianMixtureAlignment"
22
uuid = "f2431ed1-b9c2-4fdb-af1b-a74d6c93b3b3"
33
authors = ["Tom McGrath <[email protected]> and contributors"]
4-
version = "0.2.3"
4+
version = "0.3.0"
55

66
[deps]
77
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
@@ -27,13 +27,16 @@ Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
2727
GaussianMixtureAlignmentMakieExt = "Makie"
2828

2929
[compat]
30+
Aqua = "0.8"
3031
Colors = "0.12, 0.13"
3132
CoordinateTransformations = "0.6.4"
3233
Distances = "0.10"
3334
ForwardDiff = "0.10"
3435
GenericLinearAlgebra = "0.3"
3536
GeometryBasics = "0.4, 0.5"
3637
Hungarian = "0.7"
38+
IntervalSets = "0.7"
39+
LinearAlgebra = "1.10"
3740
Makie = "0.21, 0.22"
3841
MakieCore = "0.6, 0.7, 0.8, 0.9"
3942
MutableConvexHulls = "0.2"
@@ -43,12 +46,14 @@ PairedLinkedLists = "0.2"
4346
Requires = "1.3"
4447
Rotations = "1.4, 1.5, 1.6, 1.7"
4548
StaticArrays = "1.5"
49+
Test = "1.10"
4650
julia = "1.10"
4751

4852
[extras]
53+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4954
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5055
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
5156
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5257

5358
[targets]
54-
test = ["Test", "ForwardDiff", "IntervalSets"]
59+
test = ["Aqua", "Test", "ForwardDiff", "IntervalSets"]

src/branchbound.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ end
7373
# maxevals - the maximum number of objective function evaluations allowed before search termination
7474
# maxstagnant - the maximum number of `Block` splits allowed without improvement before search termination
7575
"""
76-
result = branchbound(x, y; nsplits=2, searchspace=nothing,
76+
result = branchbound(x, y; nsplits=2, searchspace=nothing,
7777
rot=nothing, trl=nothing, blockfun=fullBlock, objfun=alignment_objective,
7878
rtol=0.01, maxblocks=5e8, maxeva ls=Inf, maxstagnant=Inf, threads=false)
7979
@@ -82,12 +82,12 @@ and `y`, using the [GOGMA algorithm](https://arxiv.org/abs/1603.00150).
8282
8383
Returns a `GlobalAlignmentResult` that contains the maximized overlap of the two GMMs (the upperbound on the objective function),
8484
a lower bound on the alignment objective function, an `AffineMap` which aligns `x` with `y`, and information about the
85-
number of evaluations during the alignment procedure.
86-
"""
85+
number of evaluations during the alignment procedure.
86+
"""
8787
function branchbound(xinput::AbstractModel, yinput::AbstractModel;
8888
nsplits=2, searchspace=nothing, blockfun=UncertaintyRegion, R=RotationVec(0.,0.,0.), T=SVector{3}(0.,0.,0.),
89-
nextblockfun=lowestlbblock, centerinputs=false, boundsfun=tight_distance_bounds, localfun=local_align, tformfun=AffineMap,
90-
atol=0.1, rtol=0, maxblocks=5e8, maxsplits=Inf, maxevals=Inf, maxstagnant=Inf, separatesplit=false)
89+
nextblockfun=lowestlbblock, centerinputs=false, boundsfun=tight_distance_bounds, localfun=local_align, tformfun::TF=AffineMap,
90+
atol=0.1, rtol=0, maxblocks=5e8, maxsplits=Inf, maxevals=Inf, maxstagnant=Inf, separatesplit=false) where TF
9191
x = xinput
9292
y = yinput
9393
if isodd(nsplits)
@@ -127,12 +127,12 @@ function branchbound(xinput::AbstractModel, yinput::AbstractModel;
127127
ub, bestloc = localfun(x, y, searchspace)
128128

129129
progress = [(0, ub, bestloc)]
130-
130+
131131
# split cubes until convergence
132132
ndivisions = 0
133133
sinceimprove = 0
134134
evalsperdiv = rot_trl_split ? length(x)*length(y)*2*nsplits^3 : length(x)*length(y)*nsplits^ndims
135-
135+
136136
while !isempty(hull)
137137
if (length(hull) > maxblocks) || (ndivisions*evalsperdiv > maxevals) || (sinceimprove > maxstagnant) || (ndivisions > maxsplits)
138138
break
@@ -150,9 +150,9 @@ function branchbound(xinput::AbstractModel, yinput::AbstractModel;
150150

151151
# if the best solution so far is close enough to the best possible solution, end
152152
if abs((ub - lb)/lb) < rtol || abs(ub-lb) < atol
153-
tform = tformfun(bestloc)
153+
tform = build_tform(tformfun, bestloc)
154154
if centerinputs
155-
tform = centerx_tform tform inv(centery_tform)
155+
tform = centerx_tform tform inv(centery_tform)
156156
end
157157
return GlobalAlignmentResult(x, y, ub, lb, tform, bestloc, ndivisions*evalsperdiv, ndivisions, length(hull), sinceimprove, progress, "optimum within tolerance")
158158
end
@@ -222,17 +222,17 @@ function branchbound(xinput::AbstractModel, yinput::AbstractModel;
222222
end
223223
end
224224
if isempty(hull)
225-
tform = tformfun(bestloc)
225+
tform = build_tform(tformfun, bestloc)
226226
if centerinputs
227227
tform = centerx_tform tform inv(centery_tform)
228228
end
229-
return GlobalAlignmentResult(x, y, ub, lb, tformfun(bestloc), bestloc, ndivisions*evalsperdiv, ndivisions, length(hull), sinceimprove, progress, "priority queue empty")
229+
return GlobalAlignmentResult(x, y, ub, lb, build_tform(tformfun, bestloc), bestloc, ndivisions*evalsperdiv, ndivisions, length(hull), sinceimprove, progress, "priority queue empty")
230230
else
231-
tform = tformfun(bestloc)
231+
tform = build_tform(tformfun, bestloc)
232232
if centerinputs
233233
tform = centerx_tform tform inv(centery_tform)
234234
end
235-
return GlobalAlignmentResult(x, y, ub, lowestlbnode(hull).data[1], tformfun(bestloc), bestloc, ndivisions*evalsperdiv, ndivisions, length(hull), sinceimprove, progress, "terminated early")
235+
return GlobalAlignmentResult(x, y, ub, lowestlbnode(hull).data[1], build_tform(tformfun, bestloc), bestloc, ndivisions*evalsperdiv, ndivisions, length(hull), sinceimprove, progress, "terminated early")
236236
end
237237
end
238238

@@ -296,22 +296,22 @@ function planefit(mgmm::AbstractIsotropicMultiGMM, R)
296296
return planefit(R * ptsmat)
297297
end
298298

299-
function tiv_branchbound( x::AbstractModel,
300-
y::AbstractModel,
301-
tivx::AbstractModel,
302-
tivy::AbstractModel;
303-
boundsfun=tight_distance_bounds,
304-
rot_boundsfun=boundsfun,
305-
trl_boundsfun=boundsfun,
306-
localfun=local_align,
299+
function tiv_branchbound( x::AbstractModel,
300+
y::AbstractModel,
301+
tivx::AbstractModel,
302+
tivy::AbstractModel;
303+
boundsfun=tight_distance_bounds,
304+
rot_boundsfun=boundsfun,
305+
trl_boundsfun=boundsfun,
306+
localfun=local_align,
307307
rot_localfun=localfun,
308308
trl_localfun=localfun,
309309
kwargs...)
310310
t = promote_type(numbertype(x),numbertype(y))
311311
p = t(π)
312312
z = zero(t)
313313
zeroTranslation = SVector{3}(z,z,z)
314-
314+
315315
rot_res = rot_branchbound(tivx, tivy; localfun=rot_localfun, boundsfun=rot_boundsfun, kwargs...)
316316
rotblock = RotationRegion(RotationVec(rot_res.tform_params...), zeroTranslation, p)
317317
rotscore, rotpos = rot_localfun(tivx, tivy, rotblock)
@@ -339,8 +339,8 @@ function tiv_branchbound( x::AbstractModel,
339339
min = trl_res.upperbound
340340
bestpos = (rot_res.tform_params..., trl_res.tform_params...)
341341
end
342-
343-
return TIVAlignmentResult(x, y, min, trl_res.lowerbound, AffineMap(bestpos), bestpos,
342+
343+
return TIVAlignmentResult(x, y, min, trl_res.lowerbound, build_tform(AffineMap, bestpos), bestpos,
344344
rot_res.obj_calls+trl_res.obj_calls, rot_res.num_splits+trl_res.num_splits,
345345
rot_res.num_blocks+trl_res.num_blocks,
346346
rot_res, trl_res)

src/draw.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,16 @@ function plot!(gd::GMMDisplay{<:NTuple{<:Any,<:AbstractIsotropicGMM}})
9696
return gd
9797
end
9898

99-
function plot!(gd::GMMDisplay{<:NTuple{<:Any,<:AbstractIsotropicMultiGMM{N,T,K}}}) where {N,T,K}
99+
@recipe(MultiGMMDisplay, g) do scene
100+
Theme(
101+
display = :wire,
102+
palette = DEFAULT_COLORS,
103+
color = nothing,
104+
label = "",
105+
)
106+
end
107+
108+
function plot!(gd::MultiGMMDisplay{<:NTuple{<:Any,<:AbstractIsotropicMultiGMM{N,T,K}}}) where {N,T,K}
100109
mgmms = [gd[i][] for i=1:length(gd)]
101110
disp = gd[:display][]
102111
color = gd[:color][]

src/gogma/bounds.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,17 @@ function pairwise_consts(gmmx::AbstractIsotropicGMM, gmmy::AbstractIsotropicGMM,
2525
return pσ, pϕ
2626
end
2727

28-
function pairwise_consts(mgmmx::AbstractMultiGMM{N,T,K}, mgmmy::AbstractMultiGMM{N,S,K}, interactions::Union{Nothing,Dict{Tuple{K,K},V}}=nothing) where {N,T,S,K,V <: Number}
29-
t = promote_type(numbertype(mgmmx),numbertype(mgmmy), isnothing(interactions) ? numbertype(mgmmx) : V)
28+
function pairwise_consts(mgmmx::AbstractMultiGMM{N,T,K}, mgmmy::AbstractMultiGMM{N,S,K}, interactions::Nothing=nothing) where {N,T,S,K}
29+
t = promote_type(T, S)
30+
self_interactions = Dict{Tuple{K,K},t}()
31+
for key in keys(mgmmx.gmms) keys(mgmmy.gmms)
32+
self_interactions[(key,key)] = one(t)
33+
end
34+
pairwise_consts(mgmmx, mgmmy, self_interactions)
35+
end
36+
37+
function pairwise_consts(mgmmx::AbstractMultiGMM{N,T,K}, mgmmy::AbstractMultiGMM{N,S,K}, interactions::Dict{Tuple{K,K},V}) where {N,T,S,K,V <: Number}
38+
t = promote_type(T, S, isnothing(interactions) ? T : V)
3039
xkeys = keys(mgmmx.gmms)
3140
ykeys = keys(mgmmy.gmms)
3241
if isnothing(interactions)

src/goicp/icp.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@ function iterate_kabsch(P, Q, wp=ones(size(P,2)), wq=ones(size(Q,2)); iterations
1515

1616
prevscore = score
1717
tform = kabsch_matches(P, Q, matches, wp, wq)
18-
score = squared_deviation(tform(P),Q,matches)
18+
Pt = transform_columns(tform, P)
19+
score = squared_deviation(Pt,Q,matches)
1920
if prevscore < score
2021
matches = prevmatches
2122
break
2223
end
2324

2425
prevmatches = matches
25-
matches = correspondence(tform(P),Q)
26+
matches = correspondence(Pt,Q)
2627
if matches == prevmatches
2728
break
2829
end
@@ -63,7 +64,7 @@ end
6364
function local_matching_alignment(x::AbstractPointSet, y::AbstractPointSet, block::TranslationRegion; matching_fun = iterative_hungarian, kwargs...)
6465
tformedx = block.R*x + block.T
6566
matches = matching_fun(tformedx, y; kwargs...)
66-
tform = kabsch(x, y, matches)
67+
tform = kabsch_matches(x, y, matches)
6768
score = squared_deviation(tform(x), y, matches)
6869
params = (tform.translation...,)
6970
return (score, params)

src/goicp/kabsch.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@ end
1414
kabsch_centered_matches(P::PointSet, Q::PointSet, matches::AbstractVector{<:Tuple{Int,Int}}) = kabsch_centered_matches(P.coords, Q.coords, matches, P.weights, Q.weights);
1515

1616
# transform DxN matrices
17-
function (tform::Translation)(A::AbstractMatrix)
18-
return hcat([tform(A[:,i]) for i=1:size(A,2)]...)
17+
function transform_columns(tform::Translation, A::AbstractMatrix)
18+
return reduce(hcat, [tform(A[:,i]) for i=1:size(A,2)])
1919
end
20+
function transform_columns(tform::AffineMap, A::AbstractMatrix)
21+
l = LinearMap(tform.linear)
22+
t = Translation(tform.translation)
23+
transform_columns(t, l(A))
24+
end
25+
transform_columns(tform::Union{Translation,AffineMap}, P::PointSet) = PointSet(transform_columns(tform, P.coords), P.weights)
2026

2127

2228
function kabsch_matches(P,Q,matches::AbstractVector{<:Tuple{Int,Int}},wp=ones(size(P,2)),wq=ones(size(Q,2)))
@@ -27,7 +33,7 @@ end
2733

2834
kabsch_matches(P::PointSet, Q::PointSet, matches::AbstractVector{<:Tuple{Int,Int}}) = kabsch_matches(P.coords, Q.coords, matches, P.weights, Q.weights);
2935

30-
function kabsch_matches(P::AbstractMultiPointSet{N,T,K}, Q::AbstractMultiPointSet{N,T,K}, matchesdict, wp = weights(P), wq = weights(Q)) where {N,T,K}
36+
function kabsch_matches(P::AbstractMultiPointSet{N,T,K}, Q::AbstractMultiPointSet{N,T,K}, matchesdict::Dict, wp = weights(P), wq = weights(Q)) where {N,T,K}
3137
matchedP, matchedQ = matched_points(P,Q,matchesdict)
3238
w = Vector{T}()
3339

src/goicp/local.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function align_local_points(P, Q; maxevals=1000, tformfun=AffineMap)
1111

1212
# local optimization within the block
1313
function f(X)
14-
tform = tformfun((X...,))
14+
tform = build_tform(tformfun, X)
1515
score = squared_deviation(tform(P), Q)
1616
return score
1717
end
@@ -31,7 +31,7 @@ function iterate_local_alignment(P, Q; correspondence = hungarian_assignment, it
3131
it += 1
3232
matchedP, matchedQ = matched_points(P,Q,matches)
3333
score, tformparams = align_local_points(matchedP, matchedQ; tformfun=tformfun, kwargs...)
34-
tform = tformfun(tformparams)
34+
tform = build_tform(tformfun, tformparams)
3535
prevmatches = matches
3636
matches = correspondence(tform(P), Q)
3737
if matches == prevmatches
@@ -45,7 +45,7 @@ function iterate_local_alignment(P, Q, block; tformfun=AffineMap, kwargs...)
4545
block_tform = AffineMap(block.R, block.T)
4646
tformedP = block_tform(P)
4747
score, opt_tformparams = iterate_local_alignment(tformedP, Q; tformfun=tformfun, kwargs...)
48-
opt_tform = tformfun(opt_tformparams)
48+
opt_tform = build_tform(tformfun, opt_tformparams)
4949
tform = opt_tform block_tform
5050
tform = AffineMap(RotationVec(tform.linear), tform.translation)
5151
tformparams = tformfun === LinearMap ? (tform.linear.sx, tform.linear.sy, tform.linear.sz) :

src/goicp/rmsd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ squared_deviation(P::AbstractMatrix, Q::AbstractMatrix, matches::AbstractVector{
33
squared_deviation(P::AbstractPointSet, Q::AbstractPointSet, matches::AbstractVector{<:Tuple{Int,Int}}) = squared_deviation(P.coords, Q.coords, matches, P.weights, Q.weights)
44
squared_deviation(P::AbstractPointSet, Q::AbstractPointSet) = squared_deviation(P.coords, Q.coords, hungarian_assignment(P.coords,Q.coords), P.weights, Q.weights)
55

6-
function squared_deviation(P::AbstractMultiPointSet{N,T,K}, Q::AbstractMultiPointSet{N,T,K}, matchesdict) where {N,T,K}
6+
function squared_deviation(P::AbstractMultiPointSet{N,T,K}, Q::AbstractMultiPointSet{N,T,K}, matchesdict::Dict) where {N,T,K}
77
sqdev = zero(T)
88
for (key, matches) in matchesdict
99
sqdev += squared_deviation(P.pointsets[key], Q.pointsets[key], matches)

src/rocs/rocsalign.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ struct ROCSAlignmentResult{D,S,T,F<:AbstractAffineMap,X<:AbstractGMM{D,S},Y<:Abs
55
tform::F
66
end
77

8-
"""
8+
"""
99
m = second_moment(gmm, center, dim1, dim2)
1010
1111
Returns the second order moment of `gmm`
1212
"""
13-
function mass_matrix(positions::AbstractMatrix{<:Real},
14-
weights=ones(eltype(positions),size(positions,2)),
15-
widths=zeros(eltype(positions),size(positions,2)),
13+
function mass_matrix(positions::AbstractMatrix{<:Real},
14+
weights=ones(eltype(positions),size(positions,2)),
15+
widths=zeros(eltype(positions),size(positions,2)),
1616
center=centroid(positions,weights))
1717
t = eltype(positions)
1818
npts = size(positions,2)
@@ -46,7 +46,7 @@ is made diagonal, and the GMM center of mass is made the origin.
4646
"""
4747
function inertial_transforms(positions::AbstractMatrix{<:Real},
4848
weights=ones(eltype(positions),size(positions,2)),
49-
widths=zeros(eltype(positions),size(positions,2));
49+
widths=zeros(eltype(positions),size(positions,2));
5050
invert = false)
5151
com = centroid(positions, weights / sum(weights))
5252
massmat = mass_matrix(positions, weights, widths, com)
@@ -60,7 +60,7 @@ function inertial_transforms(positions::AbstractMatrix{<:Real},
6060

6161
# first, align the the eigenvectors to the coordinate system axes
6262
# make sure that a reflection is not performed
63-
if det(evecs) < 0
63+
if det(evecs) < 0
6464
evecs[:,end] = -evecs[:,end]
6565
end
6666

@@ -100,11 +100,11 @@ function rocs_align(gmmmoving::AbstractGMM, gmmfixed::AbstractGMM; kwargs...)
100100

101101
# combine the inertial transform with the subsequent alignment transform for the best result
102102
minoverlap, mindex = findmin([r[1] for r in results])
103-
tformmoving = AffineMap(results[mindex][2]) tformsmoving[mindex]
103+
tformmoving = build_tform(AffineMap, results[mindex][2]) tformsmoving[mindex]
104104

105-
# Apply the inverse of `tformfixed` to the optimized transformation
105+
# Apply the inverse of `tformfixed` to the optimized transformation
106106
alignment_tform = inv(tformfixed) tformmoving
107-
107+
108108
# return the result
109109
return ROCSAlignmentResult(gmmmoving, gmmfixed, minoverlap, alignment_tform)
110110
end

src/tforms.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ import CoordinateTransformations.AffineMap
22
import CoordinateTransformations.LinearMap
33
import CoordinateTransformations.Translation
44

5-
AffineMap(params::NTuple{6}) = AffineMap(RotationVec(params[1:3]...), SVector{3}(params[4:6]...))
6-
LinearMap(params::NTuple{3}) = LinearMap(RotationVec(params...))
7-
Translation(params::NTuple{3}) = Translation(SVector{3}(params))
5+
build_tform(::Type{AffineMap}, params::NTuple{6}) = AffineMap(RotationVec(params[1:3]...), SVector{3}(params[4:6]...))
6+
build_tform(::Type{LinearMap}, params::NTuple{3}) = LinearMap(RotationVec(params...))
7+
build_tform(::Type{Translation}, params::NTuple{3}) = Translation(SVector{3}(params))
88

99
function affinemap_to_params(tform::AffineMap)
1010
R = RotationVec(tform.linear)

0 commit comments

Comments
 (0)