11loose_distance_bounds (x:: AbstractGaussian , y:: AbstractGaussian , args... ) = loose_distance_bounds (x. μ, y. μ, args... )
22tight_distance_bounds (x:: AbstractGaussian , y:: AbstractGaussian , args... ) = tight_distance_bounds (x. μ, y. μ, args... )
33
4+ function validate_interactions (interactions:: Dict{Tuple{K,K},V} ) where {K,V<: Number }
5+ for (k1,k2) in keys (interactions)
6+ if k1 != k2
7+ if haskey (interactions, (k2,k1))
8+ return false
9+ end
10+ end
11+ end
12+ return true
13+ end
14+
415# prepare pairwise values for `σx^2 + σy^2` and `ϕx * ϕy` for all gaussians in `gmmx` and `gmmy`
516function pairwise_consts (gmmx:: AbstractIsotropicGMM , gmmy:: AbstractIsotropicGMM , interactions= nothing )
617 t = promote_type (numbertype (gmmx),numbertype (gmmy))
@@ -14,26 +25,31 @@ function pairwise_consts(gmmx::AbstractIsotropicGMM, gmmy::AbstractIsotropicGMM,
1425 return pσ, pϕ
1526end
1627
17- function pairwise_consts (mgmmx:: AbstractMultiGMM{N,T,K} , mgmmy:: AbstractMultiGMM{N,S,K} , interactions:: Union{Nothing,Dict{K,Dict{K,V} }} = nothing ) where {N,T,S,K,V <: Number }
28+ function pairwise_consts (mgmmx:: AbstractMultiGMM{N,T,K} , mgmmy:: AbstractMultiGMM{N,S,K} , interactions:: Union{Nothing,Dict{Tuple{K,K},V }} = nothing ) where {N,T,S,K,V <: Number }
1829 t = promote_type (numbertype (mgmmx),numbertype (mgmmy), isnothing (interactions) ? numbertype (mgmmx) : V)
1930 xkeys = keys (mgmmx. gmms)
2031 ykeys = keys (mgmmy. gmms)
2132 if isnothing (interactions)
22- interactions = Dict {K, Dict{K, t} } ()
33+ interactions = Dict {Tuple{K,K},t } ()
2334 for key in xkeys ∩ ykeys
24- interactions[key] = Dict {K, t} ( key => one (t) )
35+ interactions[( key, key)] = one (t)
2536 end
37+ else
38+ @assert validate_interactions (interactions) " Interactions must not include redundant key pairs (i.e. (k1,k2) and (k2,k1))"
2639 end
2740 mpσ, mpϕ = Dict {K, Dict{K, Matrix{t}}} (), Dict {K, Dict{K,Matrix{t}}} ()
28- for key1 in keys (interactions)
41+ ukeys = unique (Iterators. flatten (keys (interactions)))
42+ for key1 in ukeys
2943 if key1 ∈ xkeys
3044 push! (mpσ, key1 => Dict {K, Matrix{t}} ())
3145 push! (mpϕ, key1 => Dict {K, Matrix{t}} ())
32- for key2 in keys (interactions[key1])
33- if key2 ∈ ykeys
46+ for key2 in ukeys
47+ keypair = (key1,key2)
48+ keypair = haskey (interactions, keypair) ? keypair : (key2,key1)
49+ if key2 ∈ ykeys && haskey (interactions, keypair)
3450 pσ, pϕ = pairwise_consts (mgmmx. gmms[key1], mgmmy. gmms[key2])
3551 push! (mpσ[key1], key2 => pσ)
36- push! (mpϕ[key1], key2 => interactions[key1][key2 ] .* pϕ)
52+ push! (mpϕ[key1], key2 => interactions[keypair ] .* pϕ)
3753 end
3854 end
3955 if isempty (mpσ[key1])
0 commit comments