Skip to content

Commit 7843de7

Browse files
committed
add mfvi as vi baseline
1 parent 0cd0cab commit 7843de7

File tree

2 files changed

+176
-0
lines changed

2 files changed

+176
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
using Random, Distributions
2+
using LinearAlgebra
3+
using LogDensityProblems, LogDensityProblemsAD
4+
using JLD2
5+
6+
using MixFlow
7+
const MF = MixFlow
8+
9+
10+
include(joinpath(@__DIR__, "../../Model.jl"))
11+
include(joinpath(@__DIR__, "../../julia_env/flowlayer.jl"))
12+
13+
"""
    run_baseline(seed, name::String, lr;
                 batchsize=64, niters=50_000, show_progress=true,
                 nsample_eval=1024, save_jld=true)

Train a mean-field VI (MFVI) baseline for the target model `name` and return a
one-row `DataFrame` with columns `time`, `elbo`, `logZ`, `ess`.

# Arguments
- `seed`: RNG seed for reproducibility.
- `name`: model identifier passed to `load_model`.
- `lr`: Adam learning rate.

# Keywords
- `batchsize`: number of Monte-Carlo samples per ELBO gradient estimate.
- `niters`: maximum number of optimization iterations.
- `show_progress`: show the training progress bar.
- `nsample_eval`: number of samples for post-training evaluation.
- `save_jld`: save the trained flow to `result/<name>_mfvi_<lr>_<seed>.jld2`.

If training diverges (loss becomes NaN/Inf) all metrics are reported as `NaN`.
"""
function run_baseline(
    seed, name::String, lr;
    batchsize::Int=64, niters::Int=50_000, show_progress=true,
    nsample_eval::Int=1024, save_jld::Bool=true,
)
    Random.seed!(seed)

    @info "load model $(name)"
    target, dims, ad = load_model(name)

    @info "learning mfvi for $(name), dims = $(dims)"
    dim = LogDensityProblems.dimension(target)
    logp = Base.Fix1(LogDensityProblems.logdensity, target)

    # Mean-field Gaussian: standard-normal base distribution pushed through a
    # diagonal affine map (shift ∘ scale).
    # NOTE(review): the original was missing the composition operator `∘`
    # between Shift and Scale, which is a syntax error; restored here.
    q₀ = MvNormal(zeros(dim), I)
    flow = Bijectors.transformed(
        q₀, Bijectors.Shift(zeros(dim)) ∘ Bijectors.Scale(ones(dim))
    )

    # Per-iteration stats reported to `train_flow`.
    # NOTE(review): the original returned an undefined `sample_per_iter`
    # variable (UndefVarError when the callback fires); the batch size is the
    # number of samples drawn per iteration, so report that.
    cb(iter, opt_stats, re, θ) = (sample_per_iter = batchsize, ad = ad)
    # Stop early on divergence (NaN/Inf loss) or a small gradient norm.
    checkconv(iter, stat, re, θ, st) =
        _is_nan_or_inf(stat.loss) || (stat.gradient_norm < 1e-3)

    time = @elapsed begin
        flow_trained, stats, _ = train_flow(
            NormalizingFlows.elbo,
            flow,
            logp,
            batchsize;
            max_iters=niters,
            optimiser=Optimisers.Adam(lr),
            ADbackend=ad,
            show_progress=show_progress,
            hasconverged=checkconv,
            callback=cb,
        )
    end
    @info "Training finished"

    # If training stopped because the loss went NaN/Inf, report NaN metrics.
    if _is_nan_or_inf(stats[end].loss)
        @warn "Training failed: loss is NaN or Inf"
        return DataFrame(
            time = NaN,
            elbo = NaN,
            logZ = NaN,
            ess = NaN,
        )
    end

    # Evaluate ELBO, log normalizing constant, and effective sample size.
    el, logz, es = flow_sample_eval(logp, flow_trained; nsample = nsample_eval)

    # Persist the trained flow and its training configuration.
    if save_jld
        res_dir = joinpath(@__DIR__, "result/")
        # mkpath is idempotent and creates intermediate dirs if needed.
        isdir(res_dir) || mkpath(res_dir)

        JLD2.save(
            joinpath(res_dir, "$(name)_mfvi_$(lr)_$(seed).jld2"),
            "flow", flow_trained,
            "batchsize", batchsize,
            "seed", seed,
        )
    end

    return DataFrame(
        time = time,
        elbo = el,
        logZ = logz,
        ess = es,
    )
end
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
include { crossProduct; filed; deliverables } from '../../nf-nest/cross.nf'
2+
include { instantiate; precompile; activate } from '../../nf-nest/pkg.nf'
3+
include { combine_csvs; } from '../../nf-nest/combine.nf'
4+
5+
// Pipeline parameters; dryRun shrinks the evaluation sample count.
params.dryRun = false
params.n_sample_eval = params.dryRun ? 8 : 1024
params.nrunThreads = 1

// Shared Julia environment and the driver script for this baseline.
// NOTE(review): this pipeline is the MFVI baseline but points at
// `run_ais.jl` — looks like a copy-paste leftover; confirm the script name.
def julia_env = file("${moduleDir}/../../julia_env")
def julia_script = file("${moduleDir}/run_ais.jl")

// Experiment grid: every combination of these values is one run.
def variables = [
    target: ["Sonar", "Brownian", "TReg", "SparseRegression" ,"LGCP"],
    lr: ["1e-3"],
    batchsize: [64],
    niters: [50000],
    seed: 1..10,
]
19+
20+
workflow {
    // Build the Julia environment once, then fan out over the config grid.
    compiled_env = instantiate(julia_env) | precompile
    configs = crossProduct(variables, params.dryRun)
    // Each config produces a summary CSV; merge them into one table.
    combined = run_simulation(compiled_env, configs) | combine_csvs
    // plot(compiled_env, plot_script, combined)
    final_deliverable(compiled_env, combined)
}
27+
28+
29+
// Run one MFVI baseline configuration and write its summary CSV.
process run_simulation {
    debug false
    // Double memory and walltime on each retry attempt.
    memory { 30.GB * Math.pow(2, task.attempt-1) }
    time { 24.hour * Math.pow(2, task.attempt-1) }
    cpus 1
    errorStrategy { task.attempt < 2 ? 'retry' : 'ignore' }
    input:
        path julia_env
        val config
    output:
        path "${filed(config)}"
    """
    ${activate(julia_env,params.nrunThreads)}

    include("$julia_script")

    # get configurations
    seed = ${config.seed}
    name = "${config.target}"
    niters = ${config.niters}
    bs = ${config.batchsize}
    lr = ${config.lr}

    # run simulation
    # NOTE(review): `try`/`catch` introduces a new scope in Julia, so a `df`
    # assigned only inside it is NOT visible afterwards (the original would
    # throw UndefVarError at CSV.write). Pre-assign the NaN fallback and
    # rebind the global explicitly on success.
    df = DataFrame(
        time = NaN,
        elbo = NaN,
        logZ = NaN,
        ess = NaN,
    )
    try
        global df = run_baseline(
            seed, name, lr;
            batchsize=bs, niters=niters, show_progress=false,
            nsample_eval=${params.n_sample_eval},
            save_jld = true
        )
    catch e
        # Keep the NaN fallback but surface the failure in the task log.
        @warn "run_baseline failed" exception=(e, catch_backtrace())
    end

    # store output
    mkdir("${filed(config)}")
    CSV.write("${filed(config)}/summary.csv", df)
    """
}
75+
76+
77+
// Publish the combined results folder to the deliverables directory.
process final_deliverable {
    input:
        path julia_env
        path combined_csvs_folder
    output:
        path combined_csvs_folder
    publishDir "${deliverables(workflow, params)}", mode: 'copy', overwrite: true
    """
    ${activate(julia_env)}
    """
}

0 commit comments

Comments
 (0)