@@ -11,7 +11,7 @@ function create_neural_spline_flow(name, nlayers)
1111 q0 = JLD2. load (joinpath (@__DIR__ , " ../reference/result/$(name) _mfvi.jld2" ))[" reference" ]
1212
1313 dims = length (q0)
14- hdims = min (dims, 100 )
14+ hdims = min (dims, 64 )
1515 mask_idx1 = 1 : 2 : dims
1616 mask_idx2 = 2 : 2 : dims
1717
@@ -28,7 +28,7 @@ function create_real_nvp(name, nlayers)
2828 q0 = JLD2. load (joinpath (@__DIR__ , " ../reference/result/$(name) _mfvi.jld2" ))[" reference" ]
2929
3030 dims = length (q0)
31- hdims = min (dims, 100 )
31+ hdims = min (dims, 64 )
3232 mask_idx1 = 1 : 2 : dims
3333 mask_idx2 = 2 : 2 : dims
3434
@@ -45,7 +45,11 @@ function run_norm_flow(
4545 nsample_eval:: Int = 128 , save_jld:: Bool = true ,
4646)
4747 Random. seed! (seed)
48- target, _, ad = load_model (name)
48+ target, _, _ = load_model (name)
49+
50+ # mooncake is much faster for large nn
51+ ad = AutoMooncake (; config = Mooncake. Config ())
52+
4953 logp = Base. Fix1 (LogDensityProblems. logdensity, target)
5054
5155 # create flow
@@ -63,25 +67,26 @@ function run_norm_flow(
6367 # stop if nan or inf in training
6468 checkconv (iter, stat, re, θ, st) = _is_nan_or_inf (stat. loss) || (stat. gradient_norm < 1e-3 )
6569
66- time = @elapsed begin
67- flow_trained, stats, _ = train_flow (
68- NormalizingFlows. elbo,
69- flow,
70- logp,
71- batchsize;
72- max_iters= niters,
73- optimiser= Optimisers. Adam (lr),
74- ADbackend= ad,
75- show_progress= show_progress,
76- hasconverged= checkconv,
77- )
70+ time_train = @elapsed begin
71+ flow_trained, stats, _ = train_flow (
72+ NormalizingFlows. elbo,
73+ flow,
74+ logp,
75+ batchsize;
76+ max_iters= niters,
77+ optimiser= Optimisers. Adam (lr),
78+ ADbackend= ad,
79+ show_progress= show_progress,
80+ hasconverged= checkconv,
81+ )
7882 end
7983 @info " Training finished"
8084
8185 # if early stop due to NaN or Inf, return NaN for all
8286 if _is_nan_or_inf (stats[end ]. loss)
8387 println (" Training failed: loss is NaN or Inf" )
8488 return DataFrame (
89+ time = NaN ,
8590 elbo = NaN ,
8691 logZ = NaN ,
8792 ess = NaN ,
@@ -109,7 +114,7 @@ function run_norm_flow(
109114 end
110115
111116 df = DataFrame (
112- time = time ,
117+ time = time_train ,
113118 elbo= el,
114119 logZ= logz,
115120 ess= es,
@@ -120,21 +125,20 @@ end
120125
121126
122127
123- # # target_list = ["TReg", "SparseRegression", "Brownian", "Sonar", "LGCP"]
124-
128+ # target_list = ["TReg", "SparseRegression", "Brownian", "Sonar"]
125129
126130# for name in target_list
127131# @info "Running $name"
128132# df = run_norm_flow(
129- # 1, name, "neural_spline_flow", 3, 1e-4 ;
130- # batchsize=64 , niters=100 , show_progress=true,
133+ # 1, name, "neural_spline_flow", 3, 1e-3 ;
134+ # batchsize=32 , niters=10 , show_progress=true,
131135# nsample_eval=128,
132136# )
133137# end
134138
135139# name = "LGCP"
136140# df = run_norm_flow(
137141# 1, name, "real_nvp", 5, 1e-4;
138- # batchsize=64 , niters=100, show_progress=true,
142+ # batchsize=32 , niters=100, show_progress=true,
139143# nsample_eval=128,
140144# )
0 commit comments