Skip to content

Commit 863d7ec

Browse files
Dispatch progress bar method in EnsembleProblems
1 parent cec3e08 commit 863d7ec

File tree

4 files changed

+254
-160
lines changed

4 files changed

+254
-160
lines changed

src/QuantumToolbox.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ import SciMLBase:
2424
ODEProblem,
2525
SDEProblem,
2626
EnsembleProblem,
27+
EnsembleSerial,
2728
EnsembleThreads,
29+
EnsembleDistributed,
2830
FullSpecialize,
2931
CallbackSet,
3032
ContinuousCallback,

src/time_evolution/mcsolve.jl

Lines changed: 120 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,29 @@ function _mcsolve_prob_func(prob, i, repeat)
8080
return remake(prob, p = prm)
8181
end
8282

83+
# Standard output function
8384
function _mcsolve_output_func(sol, i)
8485
resize!(sol.prob.p.jump_times, sol.prob.p.jump_times_which_idx[] - 1)
8586
resize!(sol.prob.p.jump_which, sol.prob.p.jump_times_which_idx[] - 1)
86-
put!(sol.prob.p.progr_channel, true)
8787
return (sol, false)
8888
end
8989

90+
# Output function with progress bar update
91+
function _mcsolve_output_func_progress(sol, i)
92+
next!(sol.prob.p.progr_trajectories)
93+
return _mcsolve_output_func(sol, i)
94+
end
95+
96+
# Output function with distributed channel update for progress bar
97+
function _mcsolve_output_func_distributed(sol, i)
98+
put!(sol.prob.p.progr_channel, true)
99+
return _mcsolve_output_func(sol, i)
100+
end
101+
102+
_mcsolve_dispatch_output_func() = _mcsolve_output_func
103+
_mcsolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _mcsolve_output_func_progress
104+
_mcsolve_dispatch_output_func(::EnsembleDistributed) = _mcsolve_output_func_distributed
105+
90106
function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which)
91107
sol_i = sol[:, i]
92108
!isempty(sol_i.prob.kwargs[:saveat]) ?
@@ -293,9 +309,12 @@ end
293309
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
294310
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
295311
params::NamedTuple=NamedTuple(),
312+
ntraj::Int=1,
313+
ensemble_method=EnsembleThreads(),
296314
jump_callback::TJC=ContinuousLindbladJumpCallback(),
297315
prob_func::Function=_mcsolve_prob_func,
298316
output_func::Function=_mcsolve_output_func,
317+
progress_bar::Union{Val,Bool}=Val(true),
299318
kwargs...)
300319
301320
Generates the `EnsembleProblem` of `ODEProblem`s for the ensemble of trajectories of the Monte Carlo wave function time evolution of an open quantum system.
@@ -343,9 +362,12 @@ If the environmental measurements register a quantum jump, the wave function und
343362
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
344363
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
345364
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
365+
- `ntraj::Int`: Number of trajectories to use.
366+
- `ensemble_method`: Ensemble method to use.
346367
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
347368
- `prob_func::Function`: Function to use for generating the ODEProblem.
348369
- `output_func::Function`: Function to use for generating the output of a single trajectory.
370+
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
349371
- `kwargs...`: Additional keyword arguments to pass to the solver.
350372
351373
# Notes
@@ -369,29 +391,51 @@ function mcsolveEnsembleProblem(
369391
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
370392
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
371393
params::NamedTuple = NamedTuple(),
394+
ntraj::Int = 1,
395+
ensemble_method = EnsembleThreads(),
372396
jump_callback::TJC = ContinuousLindbladJumpCallback(),
373397
seeds::Union{Nothing,Vector{Int}} = nothing,
374398
prob_func::Function = _mcsolve_prob_func,
375-
output_func::Function = _mcsolve_output_func,
399+
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
400+
progress_bar::Union{Val,Bool} = Val(true),
376401
kwargs...,
377402
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
378-
prob_mc = mcsolveProblem(
379-
H,
380-
ψ0,
381-
tlist,
382-
c_ops;
383-
alg = alg,
384-
e_ops = e_ops,
385-
H_t = H_t,
386-
params = params,
387-
seeds = seeds,
388-
jump_callback = jump_callback,
389-
kwargs...,
390-
)
403+
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
404+
if ensemble_method isa EnsembleDistributed
405+
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
406+
@async while take!(progr_channel)
407+
next!(progr)
408+
end
409+
params = merge(params, (progr_channel = progr_channel,))
410+
else
411+
params = merge(params, (progr_trajectories = progr,))
412+
end
413+
414+
# Stop the async task if an error occurs
415+
try
416+
prob_mc = mcsolveProblem(
417+
H,
418+
ψ0,
419+
tlist,
420+
c_ops;
421+
alg = alg,
422+
e_ops = e_ops,
423+
H_t = H_t,
424+
params = params,
425+
seeds = seeds,
426+
jump_callback = jump_callback,
427+
kwargs...,
428+
)
391429

392-
ensemble_prob = EnsembleProblem(prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false)
430+
ensemble_prob = EnsembleProblem(prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false)
393431

394-
return ensemble_prob
432+
return ensemble_prob
433+
catch e
434+
if ensemble_method isa EnsembleDistributed
435+
put!(progr_channel, false)
436+
end
437+
rethrow()
438+
end
395439
end
396440

397441
@doc raw"""
@@ -408,7 +452,7 @@ end
408452
ensemble_method = EnsembleThreads(),
409453
jump_callback::TJC = ContinuousLindbladJumpCallback(),
410454
prob_func::Function = _mcsolve_prob_func,
411-
output_func::Function = _mcsolve_output_func,
455+
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
412456
progress_bar::Union{Val,Bool} = Val(true),
413457
kwargs...,
414458
)
@@ -493,43 +537,34 @@ function mcsolve(
493537
ensemble_method = EnsembleThreads(),
494538
jump_callback::TJC = ContinuousLindbladJumpCallback(),
495539
prob_func::Function = _mcsolve_prob_func,
496-
output_func::Function = _mcsolve_output_func,
540+
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
497541
progress_bar::Union{Val,Bool} = Val(true),
498542
kwargs...,
499543
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
500544
if !isnothing(seeds) && length(seeds) != ntraj
501545
throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))"))
502546
end
503547

504-
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
505-
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
506-
@async while take!(progr_channel)
507-
next!(progr)
508-
end
509-
510-
# Stop the async task if an error occurs
511-
try
512-
ens_prob_mc = mcsolveEnsembleProblem(
513-
H,
514-
ψ0,
515-
tlist,
516-
c_ops;
517-
alg = alg,
518-
e_ops = e_ops,
519-
H_t = H_t,
520-
params = merge(params, (progr_channel = progr_channel,)),
521-
seeds = seeds,
522-
jump_callback = jump_callback,
523-
prob_func = prob_func,
524-
output_func = output_func,
525-
kwargs...,
526-
)
548+
ens_prob_mc = mcsolveEnsembleProblem(
549+
H,
550+
ψ0,
551+
tlist,
552+
c_ops;
553+
alg = alg,
554+
e_ops = e_ops,
555+
H_t = H_t,
556+
params = params,
557+
seeds = seeds,
558+
ntraj = ntraj,
559+
ensemble_method = ensemble_method,
560+
jump_callback = jump_callback,
561+
prob_func = prob_func,
562+
output_func = output_func,
563+
progress_bar = progress_bar,
564+
kwargs...,
565+
)
527566

528-
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
529-
catch e
530-
put!(progr_channel, false)
531-
rethrow()
532-
end
567+
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
533568
end
534569

535570
function mcsolve(
@@ -538,33 +573,42 @@ function mcsolve(
538573
ntraj::Int = 1,
539574
ensemble_method = EnsembleThreads(),
540575
)
541-
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
542-
543-
put!(sol[:, 1].prob.p.progr_channel, false)
544-
545-
_sol_1 = sol[:, 1]
546-
547-
expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
548-
states =
549-
isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) :
550-
Vector{Vector{QuantumObject}}(undef, length(sol))
551-
jump_times = Vector{Vector{Float64}}(undef, length(sol))
552-
jump_which = Vector{Vector{Int16}}(undef, length(sol))
553-
554-
foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol))
555-
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)
556-
557-
return TimeEvolutionMCSol(
558-
ntraj,
559-
_sol_1.prob.p.times,
560-
states,
561-
expvals,
562-
expvals_all,
563-
jump_times,
564-
jump_which,
565-
sol.converged,
566-
_sol_1.alg,
567-
_sol_1.prob.kwargs[:abstol],
568-
_sol_1.prob.kwargs[:reltol],
569-
)
576+
try
577+
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
578+
579+
if ensemble_method isa EnsembleDistributed
580+
put!(sol[:, 1].prob.p.progr_channel, false)
581+
end
582+
583+
_sol_1 = sol[:, 1]
584+
585+
expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
586+
states =
587+
isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) :
588+
Vector{Vector{QuantumObject}}(undef, length(sol))
589+
jump_times = Vector{Vector{Float64}}(undef, length(sol))
590+
jump_which = Vector{Vector{Int16}}(undef, length(sol))
591+
592+
foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol))
593+
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)
594+
595+
return TimeEvolutionMCSol(
596+
ntraj,
597+
_sol_1.prob.p.times,
598+
states,
599+
expvals,
600+
expvals_all,
601+
jump_times,
602+
jump_which,
603+
sol.converged,
604+
_sol_1.alg,
605+
_sol_1.prob.kwargs[:abstol],
606+
_sol_1.prob.kwargs[:reltol],
607+
)
608+
catch e
609+
if ensemble_method isa EnsembleDistributed
610+
put!(ens_prob_mc.prob.p.progr_channel, false)
611+
end
612+
rethrow()
613+
end
570614
end

0 commit comments

Comments
 (0)