@@ -80,13 +80,29 @@ function _mcsolve_prob_func(prob, i, repeat)
8080 return remake (prob, p = prm)
8181end
8282
83+ # Standard output function
8384function _mcsolve_output_func (sol, i)
8485 resize! (sol. prob. p. jump_times, sol. prob. p. jump_times_which_idx[] - 1 )
8586 resize! (sol. prob. p. jump_which, sol. prob. p. jump_times_which_idx[] - 1 )
86- put! (sol. prob. p. progr_channel, true )
8787 return (sol, false )
8888end
8989
90+ # Output function with progress bar update
91+ function _mcsolve_output_func_progress (sol, i)
92+ next! (sol. prob. p. progr_trajectories)
93+ return _mcsolve_output_func (sol, i)
94+ end
95+
96+ # Output function with distributed channel update for progress bar
97+ function _mcsolve_output_func_distributed (sol, i)
98+ put! (sol. prob. p. progr_channel, true )
99+ return _mcsolve_output_func (sol, i)
100+ end
101+
102+ _mcsolve_dispatch_output_func () = _mcsolve_output_func
103+ _mcsolve_dispatch_output_func (:: ET ) where {ET<: Union{EnsembleSerial,EnsembleThreads} } = _mcsolve_output_func_progress
104+ _mcsolve_dispatch_output_func (:: EnsembleDistributed ) = _mcsolve_output_func_distributed
105+
90106function _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which)
91107 sol_i = sol[:, i]
92108 ! isempty (sol_i. prob. kwargs[:saveat ]) ?
293309 e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
294310 H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
295311 params::NamedTuple=NamedTuple(),
312+ ntraj::Int=1,
313+ ensemble_method=EnsembleThreads(),
296314 jump_callback::TJC=ContinuousLindbladJumpCallback(),
297315 prob_func::Function=_mcsolve_prob_func,
298316 output_func::Function=_mcsolve_output_func,
317+ progress_bar::Union{Val,Bool}=Val(true),
299318 kwargs...)
300319
301320Generates the `EnsembleProblem` of `ODEProblem`s for the ensemble of trajectories of the Monte Carlo wave function time evolution of an open quantum system.
@@ -343,9 +362,12 @@ If the environmental measurements register a quantum jump, the wave function und
343362- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
344363- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
345364- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
365+ - `ntraj::Int`: Number of trajectories to use.
366+ - `ensemble_method`: Ensemble method to use.
346367- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
347368- `prob_func::Function`: Function to use for generating the ODEProblem.
348369- `output_func::Function`: Function to use for generating the output of a single trajectory.
370+ - `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
349371- `kwargs...`: Additional keyword arguments to pass to the solver.
350372
351373# Notes
@@ -369,29 +391,51 @@ function mcsolveEnsembleProblem(
369391 e_ops:: Union{Nothing,AbstractVector,Tuple} = nothing ,
370392 H_t:: Union{Nothing,Function,TimeDependentOperatorSum} = nothing ,
371393 params:: NamedTuple = NamedTuple (),
394+ ntraj:: Int = 1 ,
395+ ensemble_method = EnsembleThreads (),
372396 jump_callback:: TJC = ContinuousLindbladJumpCallback (),
373397 seeds:: Union{Nothing,Vector{Int}} = nothing ,
374398 prob_func:: Function = _mcsolve_prob_func,
375- output_func:: Function = _mcsolve_output_func,
399+ output_func:: Function = _mcsolve_dispatch_output_func (ensemble_method),
400+ progress_bar:: Union{Val,Bool} = Val (true ),
376401 kwargs... ,
377402) where {MT1<: AbstractMatrix ,T2,TJC<: LindbladJumpCallbackType }
378- prob_mc = mcsolveProblem (
379- H,
380- ψ0,
381- tlist,
382- c_ops;
383- alg = alg,
384- e_ops = e_ops,
385- H_t = H_t,
386- params = params,
387- seeds = seeds,
388- jump_callback = jump_callback,
389- kwargs... ,
390- )
403+ progr = ProgressBar (ntraj, enable = getVal (progress_bar))
404+ if ensemble_method isa EnsembleDistributed
405+ progr_channel:: RemoteChannel{Channel{Bool}} = RemoteChannel (() -> Channel {Bool} (1 ))
406+ @async while take! (progr_channel)
407+ next! (progr)
408+ end
409+ params = merge (params, (progr_channel = progr_channel,))
410+ else
411+ params = merge (params, (progr_trajectories = progr,))
412+ end
413+
414+ # Stop the async task if an error occurs
415+ try
416+ prob_mc = mcsolveProblem (
417+ H,
418+ ψ0,
419+ tlist,
420+ c_ops;
421+ alg = alg,
422+ e_ops = e_ops,
423+ H_t = H_t,
424+ params = params,
425+ seeds = seeds,
426+ jump_callback = jump_callback,
427+ kwargs... ,
428+ )
391429
392- ensemble_prob = EnsembleProblem (prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false )
430+ ensemble_prob = EnsembleProblem (prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false )
393431
394- return ensemble_prob
432+ return ensemble_prob
433+ catch e
434+ if ensemble_method isa EnsembleDistributed
435+ put! (progr_channel, false )
436+ end
437+ rethrow ()
438+ end
395439end
396440
397441@doc raw """
408452 ensemble_method = EnsembleThreads(),
409453 jump_callback::TJC = ContinuousLindbladJumpCallback(),
410454 prob_func::Function = _mcsolve_prob_func,
411- output_func::Function = _mcsolve_output_func ,
455+ output_func::Function = _mcsolve_dispatch_output_func(ensemble_method) ,
412456 progress_bar::Union{Val,Bool} = Val(true),
413457 kwargs...,
414458 )
@@ -493,43 +537,34 @@ function mcsolve(
493537 ensemble_method = EnsembleThreads (),
494538 jump_callback:: TJC = ContinuousLindbladJumpCallback (),
495539 prob_func:: Function = _mcsolve_prob_func,
496- output_func:: Function = _mcsolve_output_func ,
540+ output_func:: Function = _mcsolve_dispatch_output_func (ensemble_method) ,
497541 progress_bar:: Union{Val,Bool} = Val (true ),
498542 kwargs... ,
499543) where {MT1<: AbstractMatrix ,T2,TJC<: LindbladJumpCallbackType }
500544 if ! isnothing (seeds) && length (seeds) != ntraj
501545 throw (ArgumentError (" Length of seeds must match ntraj ($ntraj ), but got $(length (seeds)) " ))
502546 end
503547
504- progr = ProgressBar (ntraj, enable = getVal (progress_bar))
505- progr_channel:: RemoteChannel{Channel{Bool}} = RemoteChannel (() -> Channel {Bool} (1 ))
506- @async while take! (progr_channel)
507- next! (progr)
508- end
509-
510- # Stop the async task if an error occurs
511- try
512- ens_prob_mc = mcsolveEnsembleProblem (
513- H,
514- ψ0,
515- tlist,
516- c_ops;
517- alg = alg,
518- e_ops = e_ops,
519- H_t = H_t,
520- params = merge (params, (progr_channel = progr_channel,)),
521- seeds = seeds,
522- jump_callback = jump_callback,
523- prob_func = prob_func,
524- output_func = output_func,
525- kwargs... ,
526- )
548+ ens_prob_mc = mcsolveEnsembleProblem (
549+ H,
550+ ψ0,
551+ tlist,
552+ c_ops;
553+ alg = alg,
554+ e_ops = e_ops,
555+ H_t = H_t,
556+ params = params,
557+ seeds = seeds,
558+ ntraj = ntraj,
559+ ensemble_method = ensemble_method,
560+ jump_callback = jump_callback,
561+ prob_func = prob_func,
562+ output_func = output_func,
563+ progress_bar = progress_bar,
564+ kwargs... ,
565+ )
527566
528- return mcsolve (ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
529- catch e
530- put! (progr_channel, false )
531- rethrow ()
532- end
567+ return mcsolve (ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
533568end
534569
535570function mcsolve (
@@ -538,33 +573,42 @@ function mcsolve(
538573 ntraj:: Int = 1 ,
539574 ensemble_method = EnsembleThreads (),
540575)
541- sol = solve (ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
542-
543- put! (sol[:, 1 ]. prob. p. progr_channel, false )
544-
545- _sol_1 = sol[:, 1 ]
546-
547- expvals_all = Array {ComplexF64} (undef, length (sol), size (_sol_1. prob. p. expvals)... )
548- states =
549- isempty (_sol_1. prob. kwargs[:saveat ]) ? fill (QuantumObject[], length (sol)) :
550- Vector {Vector{QuantumObject}} (undef, length (sol))
551- jump_times = Vector {Vector{Float64}} (undef, length (sol))
552- jump_which = Vector {Vector{Int16}} (undef, length (sol))
553-
554- foreach (i -> _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which), eachindex (sol))
555- expvals = dropdims (sum (expvals_all, dims = 1 ), dims = 1 ) ./ length (sol)
556-
557- return TimeEvolutionMCSol (
558- ntraj,
559- _sol_1. prob. p. times,
560- states,
561- expvals,
562- expvals_all,
563- jump_times,
564- jump_which,
565- sol. converged,
566- _sol_1. alg,
567- _sol_1. prob. kwargs[:abstol ],
568- _sol_1. prob. kwargs[:reltol ],
569- )
576+ try
577+ sol = solve (ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
578+
579+ if ensemble_method isa EnsembleDistributed
580+ put! (sol[:, 1 ]. prob. p. progr_channel, false )
581+ end
582+
583+ _sol_1 = sol[:, 1 ]
584+
585+ expvals_all = Array {ComplexF64} (undef, length (sol), size (_sol_1. prob. p. expvals)... )
586+ states =
587+ isempty (_sol_1. prob. kwargs[:saveat ]) ? fill (QuantumObject[], length (sol)) :
588+ Vector {Vector{QuantumObject}} (undef, length (sol))
589+ jump_times = Vector {Vector{Float64}} (undef, length (sol))
590+ jump_which = Vector {Vector{Int16}} (undef, length (sol))
591+
592+ foreach (i -> _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which), eachindex (sol))
593+ expvals = dropdims (sum (expvals_all, dims = 1 ), dims = 1 ) ./ length (sol)
594+
595+ return TimeEvolutionMCSol (
596+ ntraj,
597+ _sol_1. prob. p. times,
598+ states,
599+ expvals,
600+ expvals_all,
601+ jump_times,
602+ jump_which,
603+ sol. converged,
604+ _sol_1. alg,
605+ _sol_1. prob. kwargs[:abstol ],
606+ _sol_1. prob. kwargs[:reltol ],
607+ )
608+ catch e
609+ if ensemble_method isa EnsembleDistributed
610+ put! (ens_prob_mc. prob. p. progr_channel, false )
611+ end
612+ rethrow ()
613+ end
570614end
0 commit comments