@@ -11,12 +11,13 @@ using JLD2
1111using MixFlow
1212const MF = MixFlow
1313
14- include (joinpath (@__DIR__ , " ../../evaluation.jl" ))
14+ include (joinpath (@__DIR__ , " ../../julia_env/ evaluation.jl" ))
1515
1616
1717function run_simulation (
1818 seed, name, flowtype, kernel, T, nchains;
1919 nsample = 128 , target_rej_rate = 0.2 , track_cost = true , T_tune = T,
20+ save_jld = false
2021)
2122 Random. seed! (seed)
2223 prob, dims = load_prob_with_ref (name)
@@ -35,11 +36,22 @@ function run_simulation(
3536
3637 # rej_rate, err = rejection_rate(prob, K, T_check)
3738
38- df = flow_evaluation (seed, name, flowtype, kernel, T, ϵ; nsample = nsample, nchains = nchains)
39+ df, output = flow_evaluation (seed, name, flowtype, kernel, T, ϵ; nsample = nsample, nchains = nchains, track_cost = track_cost )
3940
4041 # add cost tuning
4142 df[! , " cost_tuning" ] .= cost_tuning
42- return df
43+
44+ jld_pth = joinpath (@__DIR__ , " results/" )
45+ if save_jld
46+ if ! isdir (jld_pth)
47+ mkpath (jld_pth)
48+ end
49+ JLD2. save (
50+ joinpath (jld_pth, " rwmh_$(name) _$seed .jld2" ),
51+ " output" => output,
52+ )
53+ end
54+ return df, output
4355end
4456
4557
6072
6173# df = flow_evaluation(1, name, flowtype, kernel, T_check, ϵ; nsample = 1024, nchains = 30)
6274
63- # df = run_simulation(
64- # 1, "SparseRegression ", MF.BackwardIRFMixFlow , RWMH, 5000, 30; nsample = 128
75+ # df, output = run_simulation(
76+ # 1, "Banana ", MF.EnsembleIRFFlow , RWMH, 5000, 30; target_rej_rate = 0.766, nsample = 128
6577# )
0 commit comments