Skip to content

Commit 1d738d3

Browse files
committed
update rwmh script
1 parent c684d23 commit 1d738d3

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

example/real_data_expt/rwmh/run_rwmh.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ using JLD2
1111
using MixFlow
1212
const MF = MixFlow
1313

14-
include(joinpath(@__DIR__, "../../evaluation.jl"))
14+
include(joinpath(@__DIR__, "../../julia_env/evaluation.jl"))
1515

1616

1717
function run_simulation(
1818
seed, name, flowtype, kernel, T, nchains;
1919
nsample = 128, target_rej_rate = 0.2, track_cost = true, T_tune = T,
20+
save_jld = false
2021
)
2122
Random.seed!(seed)
2223
prob, dims = load_prob_with_ref(name)
@@ -35,11 +36,22 @@ function run_simulation(
3536

3637
# rej_rate, err = rejection_rate(prob, K, T_check)
3738

38-
df = flow_evaluation(seed, name, flowtype, kernel, T, ϵ; nsample = nsample, nchains = nchains)
39+
df, output = flow_evaluation(seed, name, flowtype, kernel, T, ϵ; nsample = nsample, nchains = nchains, track_cost = track_cost)
3940

4041
# add cost tuning
4142
df[!, "cost_tuning"] .= cost_tuning
42-
return df
43+
44+
jld_pth = joinpath(@__DIR__, "results/")
45+
if save_jld
46+
if !isdir(jld_pth)
47+
mkpath(jld_pth)
48+
end
49+
JLD2.save(
50+
joinpath(jld_pth, "rwmh_$(name)_$seed.jld2"),
51+
"output" => output,
52+
)
53+
end
54+
return df, output
4355
end
4456

4557

@@ -60,6 +72,6 @@ end
6072

6173
# df = flow_evaluation(1, name, flowtype, kernel, T_check, ϵ; nsample = 1024, nchains = 30)
6274

63-
# df = run_simulation(
64-
# 1, "SparseRegression", MF.BackwardIRFMixFlow, RWMH, 5000, 30; nsample = 128
75+
# df, output = run_simulation(
76+
# 1, "Banana", MF.EnsembleIRFFlow, RWMH, 5000, 30; target_rej_rate = 0.766, nsample = 128
6577
# )

example/real_data_expt/rwmh/run_rwmh.nf

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ params.dryRun = false
66
params.n_sample = params.dryRun ? 8 : 64
77
params.nrunThreads = 1
88

9-
def julia_env = file("${moduleDir}/../../")
9+
def julia_env = file("${moduleDir}/../../julia_env")
1010
def julia_script = file(moduleDir/'run_rwmh.jl')
1111
// def plot_script = file(moduleDir/'tuning.jl')
1212

1313
def variables = [
1414
seed: 1..32,
15-
target: ["TReg", "Brownian", "Sonar", "SparseRegression"],
16-
flowtype: ["BackwardIRFMixFlow", "DeterministicMixFlow", "EnsembleIRFFlow"],
15+
target: ["TReg", "Brownian", "Sonar", "SparseRegression", "LGCP"],
16+
flowtype: ["BackwardIRFMixFlow", "DeterministicMixFlow", "EnsembleIRFFlow", "IRFMixFlow"],
1717
kernel: ["MF.RWMH"],
1818
nchains: [30],
1919
flow_length: [5000],
@@ -53,8 +53,8 @@ process run_simulation {
5353
nchains = ${config.nchains}
5454
5555
# run simulation
56-
df = run_simulation(seed, name, flowtype, kernel, T, nchains; nsample = ${params.n_sample})
57-
56+
df, output = run_simulation(seed, name, flowtype, kernel, T, nchains; nsample = ${params.n_sample})
57+
5858
# store output
5959
mkdir("${filed(config)}")
6060
CSV.write("${filed(config)}/summary.csv", df)

0 commit comments

Comments
 (0)