Skip to content

Commit 1166f6f

Browse files
authored
Variable ordering for exact inference. (#153)
* Add ordering code. * Update CI. * Update CI: `actions/cache@v4`. * Julia 1.6 -> 1.8. * Use Julia 1.8 in documentation workflow.
1 parent 6b7ff40 commit 1166f6f

File tree

5 files changed

+63
-9
lines changed

5 files changed

+63
-9
lines changed

.github/workflows/CI.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ jobs:
1616
fail-fast: false
1717
matrix:
1818
version:
19-
- "1" # Earliest
20-
- "1.6" # Latest
19+
- "1" # Latest
20+
- "1.8" # Earliest
2121
os:
2222
- ubuntu-latest
2323
- windows-latest
@@ -26,14 +26,14 @@ jobs:
2626
- x64
2727
steps:
2828
# check out the project and install Julia
29-
- uses: actions/checkout@v2
29+
- uses: actions/checkout@v4
3030
- uses: julia-actions/setup-julia@v1
3131
with:
3232
version: ${{ matrix.version }}
3333
arch: ${{ matrix.arch }}
3434

3535
# using a cache can speed up execution times
36-
- uses: actions/cache@v2
36+
- uses: actions/cache@v4
3737
env:
3838
cache-name: cache-artifacts
3939
with:
@@ -57,4 +57,4 @@ jobs:
5757
with:
5858
file: ./lcov.info
5959
flags: unittests
60-
name: codecov-umbrella
60+
name: codecov-umbrella

.github/workflows/Documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- uses: actions/checkout@v2
1515
- uses: julia-actions/setup-julia@latest
1616
with:
17-
version: '1.6'
17+
version: '1.8'
1818
- name: Install LuaLatex
1919
run: sudo apt-get update && sudo apt-get install texlive-full --fix-missing && sudo apt-get install texlive-latex-extra && sudo mktexlsr && sudo updmap-sys
2020
- name: Install dependencies

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "ba4760a4-c768-5bed-964b-cf806dc591cb"
33
version = "3.4.1"
44

55
[deps]
6+
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
67
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
78
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
89
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -25,14 +26,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2526
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2627

2728
[compat]
29+
CliqueTrees = "0.5.2"
2830
DataFrames = "0.22,1.0,1.1"
2931
DataStructures = "0.11,0.12,0.13,0.14,0.15,0.16,0.17,0.18"
3032
Discretizers = "3.0"
3133
Distributions = "0.17,0.18,0.19,0.20,0.21,0.22,0.23,0.24,0.25"
3234
Documenter = "0.26, 0.27"
3335
GraphPlot = "0.5"
34-
IterTools = "1.3"
3536
Graphs = "1.0"
37+
IterTools = "1.3"
3638
LightXML = "0.8,0.9"
3739
NBInclude = "2.0"
3840
Parameters = "0.10,0.11,0.12"
@@ -41,7 +43,7 @@ Reexport = "0.2, 1.0"
4143
Requires = "1.0.1"
4244
SpecialFunctions = "0.8,0.10,1.0,1.1,1.2,2"
4345
StatsBase = "0.25,0.26,0.27,0.28,0.29,0.30,0.31,0.32,0.33"
44-
julia = "1"
46+
julia = "1.8"
4547

4648
[extras]
4749
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/BayesNets.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include(joinpath("CPDs", "cpds.jl"))
1616
@reexport using BayesNets.CPDs.ProbabilisticGraphicalModels
1717

1818
import Base: *, /, +, -
19+
import CliqueTrees
1920
import DataStructures: PriorityQueue, peek
2021
import BayesNets.CPDs.ProbabilisticGraphicalModels: markov_blanket, is_independent, infer
2122
import StatsBase: sample, Weights

src/Inference/exact.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ function infer(im::ExactInference, inf::InferenceState{BN}) where {BN<:DiscreteB
99
nodes = names(bn)
1010
query = inf.query
1111
evidence = inf.evidence
12-
hidden = setdiff(nodes, vcat(query, keys(evidence)))
12+
13+
# hidden = setdiff(nodes, vcat(query, keys(evidence)))
14+
hidden = elimination_order(bn, query, evidence)
1315

1416
factors = map(n -> Factor(bn, n, evidence), nodes)
1517

@@ -31,3 +33,52 @@ function infer(im::ExactInference, inf::InferenceState{BN}) where {BN<:DiscreteB
3133
end
3234
infer(inf::InferenceState{BN}) where {BN<:DiscreteBayesNet} = infer(ExactInference(), inf)
3335
infer(bn::BN, query::NodeNameUnion; evidence::Assignment=Assignment()) where {BN<:DiscreteBayesNet} = infer(ExactInference(), InferenceState(bn, query, evidence))
36+
37+
function elimination_order(bn::BayesNet, query::AbstractVector, evidence::AbstractDict)
38+
order = Symbol[]
39+
index = Dict{Symbol, Int}()
40+
41+
for v in names(bn)
42+
if !haskey(evidence, v)
43+
push!(order, v)
44+
index[v] = length(order)
45+
end
46+
end
47+
48+
# construct reduced graph
49+
matrix = spzeros(Int, length(order), length(order))
50+
51+
for v in order
52+
i = index[v]
53+
push!(rowvals(matrix), i)
54+
push!(nonzeros(matrix), 1)
55+
56+
for w in children(bn, v)
57+
if haskey(index, w)
58+
j = index[w]
59+
push!(rowvals(matrix), j)
60+
push!(nonzeros(matrix), 1)
61+
end
62+
end
63+
64+
matrix.colptr[i + 1] = length(rowvals(matrix)) + 1
65+
end
66+
67+
# moralize graph
68+
matrix = matrix' * matrix
69+
70+
# make query variables a clique
71+
n = length(query)
72+
clique = Vector{Int}(undef, n)
73+
74+
for j in 1:n
75+
clique[j] = index[query[j]]
76+
end
77+
78+
matrix[clique, clique] .= 1
79+
# alg = CliqueTrees.MF() # minimum fill heuristic
80+
alg = CliqueTrees.MMD() # minimum degree heuristic
81+
# alg = CliqueTrees.MCS() # maximum cardinality search
82+
perm, _ = CliqueTrees.permutation(matrix; alg=CliqueTrees.CompositeRotations(clique, alg))
83+
return order[perm[1:end - n]]
84+
end

0 commit comments

Comments
 (0)