Skip to content

Commit 5b44073

Browse files
committed
fix keyword arguments for other time evolution solvers
1 parent e0969d5 commit 5b44073

File tree

4 files changed

+17
-19
lines changed

4 files changed

+17
-19
lines changed

benchmarks/timeevolution.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function benchmark_timeevolution!(SUITE)
5050
ntraj = 100,
5151
e_ops = $e_ops,
5252
progress_bar = Val(false),
53-
ensemble_method = EnsembleSerial(),
53+
ensemblealg = EnsembleSerial(),
5454
)
5555
SUITE["Time Evolution"]["time-independent"]["mcsolve"]["Multithreaded"] = @benchmarkable mcsolve(
5656
$H,
@@ -60,7 +60,7 @@ function benchmark_timeevolution!(SUITE)
6060
ntraj = 100,
6161
e_ops = $e_ops,
6262
progress_bar = Val(false),
63-
ensemble_method = EnsembleThreads(),
63+
ensemblealg = EnsembleThreads(),
6464
)
6565

6666
return nothing

src/time_evolution/time_evolution.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ end
308308
function _ensemble_dispatch_solve(
309309
ens_prob_mc::TimeEvolutionProblem,
310310
alg::Union{<:OrdinaryDiffEqAlgorithm,<:StochasticDiffEqAlgorithm},
311-
ensemble_method::ET,
311+
ensemblealg::ET,
312312
ntraj::Int,
313313
) where {ET<:Union{EnsembleSplitThreads,EnsembleDistributed}}
314314
sol = nothing
@@ -319,7 +319,7 @@ function _ensemble_dispatch_solve(
319319
end
320320

321321
@async begin
322-
sol = solve(ens_prob_mc.prob, alg, ensemble_method, trajectories = ntraj)
322+
sol = solve(ens_prob_mc.prob, alg, ensemblealg, trajectories = ntraj)
323323
put!(ens_prob_mc.kwargs.channel, false)
324324
end
325325
end
@@ -329,10 +329,10 @@ end
329329
function _ensemble_dispatch_solve(
330330
ens_prob_mc::TimeEvolutionProblem,
331331
alg::Union{<:OrdinaryDiffEqAlgorithm,<:StochasticDiffEqAlgorithm},
332-
ensemble_method,
332+
ensemblealg,
333333
ntraj::Int,
334334
)
335-
sol = solve(ens_prob_mc.prob, alg, ensemble_method, trajectories = ntraj)
335+
sol = solve(ens_prob_mc.prob, alg, ensemblealg, trajectories = ntraj)
336336
return sol
337337
end
338338

src/time_evolution/time_evolution_dynamical.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,8 @@ function dsf_mcsolveEnsembleProblem(
607607
dsf_params::NamedTuple = NamedTuple();
608608
e_ops::Function = (op_list, p) -> (),
609609
params::NamedTuple = NamedTuple(),
610-
ntraj::Int = 1,
611-
ensemble_method = EnsembleThreads(),
610+
ntraj::Int = 500,
611+
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
612612
δα_list::Vector{<:Real} = fill(0.2, length(op_list)),
613613
jump_callback::TJC = ContinuousLindbladJumpCallback(),
614614
krylov_dim::Int = min(5, cld(length(ψ0.data), 3)),
@@ -660,7 +660,7 @@ function dsf_mcsolveEnsembleProblem(
660660
e_ops = e_ops₀,
661661
params = params2,
662662
ntraj = ntraj,
663-
ensemble_method = ensemble_method,
663+
ensemblealg = ensemblealg,
664664
jump_callback = jump_callback,
665665
prob_func = _dsf_mcsolve_prob_func,
666666
progress_bar = progress_bar,
@@ -679,8 +679,8 @@ end
679679
e_ops::Function=(op_list,p) -> Vector{TOl}([]),
680680
params::NamedTuple=NamedTuple(),
681681
δα_list::Vector{<:Real}=fill(0.2, length(op_list)),
682-
ntraj::Int=1,
683-
ensemble_method=EnsembleThreads(),
682+
ntraj::Int=500,
683+
ensemblealg::EnsembleAlgorithm=EnsembleThreads(),
684684
jump_callback::LindbladJumpCallbackType=ContinuousLindbladJumpCallback(),
685685
krylov_dim::Int=max(6, min(10, cld(length(ket2dm(ψ0).data), 4))),
686686
progress_bar::Union{Bool,Val} = Val(true)
@@ -704,8 +704,8 @@ function dsf_mcsolve(
704704
e_ops::Function = (op_list, p) -> (),
705705
params::NamedTuple = NamedTuple(),
706706
δα_list::Vector{<:Real} = fill(0.2, length(op_list)),
707-
ntraj::Int = 1,
708-
ensemble_method = EnsembleThreads(),
707+
ntraj::Int = 500,
708+
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
709709
jump_callback::TJC = ContinuousLindbladJumpCallback(),
710710
krylov_dim::Int = min(5, cld(length(ψ0.data), 3)),
711711
progress_bar::Union{Bool,Val} = Val(true),
@@ -723,13 +723,13 @@ function dsf_mcsolve(
723723
e_ops = e_ops,
724724
params = params,
725725
ntraj = ntraj,
726-
ensemble_method = ensemble_method,
726+
ensemblealg = ensemblealg,
727727
δα_list = δα_list,
728728
jump_callback = jump_callback,
729729
krylov_dim = krylov_dim,
730730
progress_bar = progress_bar,
731731
kwargs...,
732732
)
733733

734-
return mcsolve(ens_prob_mc, alg, ntraj, ensemble_method)
734+
return mcsolve(ens_prob_mc, alg, ntraj, ensemblealg)
735735
end

test/core-test/dynamical-shifted-fock.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@
5959
α0_l,
6060
dsf_params,
6161
e_ops = e_ops_dsf,
62-
progress_bar = Val(false),
63-
ntraj = 500,
62+
progress_bar = Val(false)
6463
)
6564
val_ss = abs2(sol0.expect[1, end])
6665
@test sum(abs2.(sol0.expect[1, :] .- sol_dsf_me.expect[1, :])) / (val_ss * length(tlist)) < 0.1
@@ -139,8 +138,7 @@
139138
α0_l,
140139
dsf_params,
141140
e_ops = e_ops_dsf2,
142-
progress_bar = Val(false),
143-
ntraj = 500,
141+
progress_bar = Val(false)
144142
)
145143

146144
val_ss = abs2(sol0.expect[1, end])

0 commit comments

Comments
 (0)