Skip to content

Commit 42ccd37

Browse files
committed
Change the default AD backend for normalizing flows to Mooncake; it is much faster than Zygote for large neural networks
1 parent ec0c93b commit 42ccd37

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

example/real_data_expt/normflow/run_nf.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function create_neural_spline_flow(name, nlayers)
1111
q0 = JLD2.load(joinpath(@__DIR__, "../reference/result/$(name)_mfvi.jld2"))["reference"]
1212

1313
dims = length(q0)
14-
hdims = min(dims, 100)
14+
hdims = min(dims, 64)
1515
mask_idx1 = 1:2:dims
1616
mask_idx2 = 2:2:dims
1717

@@ -28,7 +28,7 @@ function create_real_nvp(name, nlayers)
2828
q0 = JLD2.load(joinpath(@__DIR__, "../reference/result/$(name)_mfvi.jld2"))["reference"]
2929

3030
dims = length(q0)
31-
hdims = min(dims, 100)
31+
hdims = min(dims, 64)
3232
mask_idx1 = 1:2:dims
3333
mask_idx2 = 2:2:dims
3434

@@ -45,7 +45,11 @@ function run_norm_flow(
4545
nsample_eval::Int=128, save_jld::Bool=true,
4646
)
4747
Random.seed!(seed)
48-
target, _, ad = load_model(name)
48+
target, _, _ = load_model(name)
49+
50+
# mooncake is much faster for large nn
51+
ad = AutoMooncake(; config = Mooncake.Config())
52+
4953
logp = Base.Fix1(LogDensityProblems.logdensity, target)
5054

5155
# create flow
@@ -63,25 +67,26 @@ function run_norm_flow(
6367
# stop if nan or inf in training
6468
checkconv(iter, stat, re, θ, st) = _is_nan_or_inf(stat.loss) || (stat.gradient_norm < 1e-3)
6569

66-
time = @elapsed begin
67-
flow_trained, stats, _ = train_flow(
68-
NormalizingFlows.elbo,
69-
flow,
70-
logp,
71-
batchsize;
72-
max_iters=niters,
73-
optimiser=Optimisers.Adam(lr),
74-
ADbackend=ad,
75-
show_progress=show_progress,
76-
hasconverged=checkconv,
77-
)
70+
time_train = @elapsed begin
71+
flow_trained, stats, _ = train_flow(
72+
NormalizingFlows.elbo,
73+
flow,
74+
logp,
75+
batchsize;
76+
max_iters=niters,
77+
optimiser=Optimisers.Adam(lr),
78+
ADbackend=ad,
79+
show_progress=show_progress,
80+
hasconverged=checkconv,
81+
)
7882
end
7983
@info "Training finished"
8084

8185
# if early stop due to NaN or Inf, return NaN for all
8286
if _is_nan_or_inf(stats[end].loss)
8387
println("Training failed: loss is NaN or Inf")
8488
return DataFrame(
89+
time = NaN,
8590
elbo = NaN,
8691
logZ = NaN,
8792
ess = NaN,
@@ -109,7 +114,7 @@ function run_norm_flow(
109114
end
110115

111116
df = DataFrame(
112-
time = time,
117+
time = time_train,
113118
elbo=el,
114119
logZ=logz,
115120
ess=es,
@@ -120,21 +125,20 @@ end
120125

121126

122127

123-
# # target_list = ["TReg", "SparseRegression", "Brownian", "Sonar", "LGCP"]
124-
128+
# target_list = ["TReg", "SparseRegression", "Brownian", "Sonar"]
125129

126130
# for name in target_list
127131
# @info "Running $name"
128132
# df = run_norm_flow(
129-
# 1, name, "neural_spline_flow", 3, 1e-4;
130-
# batchsize=64, niters=100, show_progress=true,
133+
# 1, name, "neural_spline_flow", 3, 1e-3;
134+
# batchsize=32, niters=10, show_progress=true,
131135
# nsample_eval=128,
132136
# )
133137
# end
134138

135139
# name = "LGCP"
136140
# df = run_norm_flow(
137141
# 1, name, "real_nvp", 5, 1e-4;
138-
# batchsize=64, niters=100, show_progress=true,
142+
# batchsize=32, niters=100, show_progress=true,
139143
# nsample_eval=128,
140144
# )

0 commit comments

Comments
 (0)