141141
142142function _normalize_state! (u, dims, normalize_states)
143143 getVal (normalize_states) && normalize! (u)
144- return QuantumObject (u, dims = dims)
145- end
146-
147- function _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which, normalize_states, dims)
148- sol_i = sol[:, i]
149- ! isempty (sol_i. prob. kwargs[:saveat ]) ? states[i] = map (u -> _normalize_state! (u, dims, normalize_states), sol_i. u) :
150- nothing
151-
152- copyto! (view (expvals_all, i, :, :), sol_i. prob. p. expvals)
153- jump_times[i] = sol_i. prob. p. mcsolve_params. jump_times
154- return jump_which[i] = sol_i. prob. p. mcsolve_params. jump_which
144+ return QuantumObject (u, type = Ket, dims = dims)
155145end
156146
157147function _generate_mcsolve_kwargs (e_ops, tlist, c_ops, jump_callback, kwargs)
@@ -308,15 +298,15 @@ function mcsolveProblem(
308298 kwargs2 = merge (default_values, kwargs)
309299 kwargs3 = _generate_mcsolve_kwargs (e_ops, tlist, c_ops, jump_callback, kwargs2)
310300
311- cache_mc = similar (ψ0. data)
312- weights_mc = similar (ψ0. data, length (c_ops)) # It should be a Float64 Vector, but we have to keep the same type for all the parameters due to SciMLStructures.jl
301+ cache_mc = similar (ψ0. data, T )
302+ weights_mc = similar (ψ0. data, T, length (c_ops)) # It should be a Float64 Vector, but we have to keep the same type for all the parameters due to SciMLStructures.jl
313303 cumsum_weights_mc = similar (weights_mc)
314304
315- jump_times = similar (ψ0. data, jump_times_which_init_size)
316- jump_which = similar (ψ0. data, jump_times_which_init_size)
305+ jump_times = similar (ψ0. data, T, jump_times_which_init_size)
306+ jump_which = similar (ψ0. data, T, jump_times_which_init_size)
317307 jump_times_which_idx = T[1 ] # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl
318308
319- random_n = similar (ψ0. data, 1 ) # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl.
309+ random_n = similar (ψ0. data, T, 1 ) # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl.
320310 random_n[1 ] = rand (rng)
321311
322312 progr = ProgressBar (length (tlist), enable = false )
@@ -432,24 +422,24 @@ function mcsolveEnsembleProblem(
432422 output_func:: Union{Tuple,Nothing} = nothing ,
433423 kwargs... ,
434424) where {DT1,DT2,TJC<: LindbladJumpCallbackType }
425+ _prob_func = prob_func isa Nothing ? _mcsolve_dispatch_prob_func (rng, ntraj) : prob_func
426+ _output_func =
427+ output_func isa Nothing ? _mcsolve_dispatch_output_func (ensemble_method, progress_bar, ntraj) : output_func
428+
435429 prob_mc = mcsolveProblem (
436430 H,
437431 ψ0,
438432 tlist,
439433 c_ops;
440434 e_ops = e_ops,
441435 params = params,
442- rng = deepcopy ( rng), # By deepcopying, we avoid to also count the initialization of mcsolveProblem
436+ rng = rng,
443437 jump_callback = jump_callback,
444438 kwargs... ,
445439 )
446440
447- _prob_func = prob_func isa Nothing ? _mcsolve_dispatch_prob_func (rng, ntraj) : prob_func
448- _output_func =
449- output_func isa Nothing ? _mcsolve_dispatch_output_func (ensemble_method, progress_bar, ntraj) : output_func
450-
451441 ensemble_prob = TimeEvolutionProblem (
452- EnsembleProblem (prob_mc. prob, prob_func = _prob_func, output_func = _output_func[1 ], safetycopy = false ),
442+ EnsembleProblem (prob_mc. prob, prob_func = _prob_func, output_func = _output_func[1 ], safetycopy = true ),
453443 prob_mc. times,
454444 prob_mc. dims,
455445 (progr = _output_func[2 ], channel = _output_func[3 ]),
@@ -558,7 +548,7 @@ function mcsolve(
558548 jump_callback:: TJC = ContinuousLindbladJumpCallback (),
559549 progress_bar:: Union{Val,Bool} = Val (true ),
560550 prob_func:: Union{Function,Nothing} = nothing ,
561- output_func:: Union{Function ,Nothing} = nothing ,
551+ output_func:: Union{Tuple ,Nothing} = nothing ,
562552 normalize_states:: Union{Val,Bool} = Val (true ),
563553 kwargs... ,
564554) where {DT1,DT2,TJC<: LindbladJumpCallbackType }
@@ -580,53 +570,59 @@ function mcsolve(
580570 kwargs... ,
581571 )
582572
583- return mcsolve (
584- ens_prob_mc;
585- alg = alg,
586- ntraj = ntraj,
587- ensemble_method = ensemble_method,
588- normalize_states = normalize_states,
589- )
573+ return mcsolve (ens_prob_mc, alg, ntraj, ensemble_method, normalize_states)
574+ end
575+
576+ function _mcsolve_solve_ens (
577+ ens_prob_mc:: TimeEvolutionProblem ,
578+ alg:: OrdinaryDiffEqAlgorithm ,
579+ ensemble_method:: ET ,
580+ ntraj:: Int ,
581+ ) where {ET<: Union{EnsembleSplitThreads,EnsembleDistributed} }
582+ sol = nothing
583+
584+ @sync begin
585+ @async while take! (ens_prob_mc. kwargs. channel)
586+ next! (ens_prob_mc. kwargs. progr)
587+ end
588+
589+ @async begin
590+ sol = solve (ens_prob_mc. prob, alg, ensemble_method, trajectories = ntraj)
591+ put! (ens_prob_mc. kwargs. channel, false )
592+ end
593+ end
594+
595+ return sol
596+ end
597+
598+ function _mcsolve_solve_ens (
599+ ens_prob_mc:: TimeEvolutionProblem ,
600+ alg:: OrdinaryDiffEqAlgorithm ,
601+ ensemble_method,
602+ ntraj:: Int ,
603+ )
604+ sol = solve (ens_prob_mc. prob, alg, ensemble_method, trajectories = ntraj)
605+ return sol
590606end
591607
592608function mcsolve (
593- ens_prob_mc:: TimeEvolutionProblem ;
609+ ens_prob_mc:: TimeEvolutionProblem ,
594610 alg:: OrdinaryDiffEqAlgorithm = Tsit5 (),
595611 ntraj:: Int = 1 ,
596612 ensemble_method = EnsembleThreads (),
597- normalize_states:: Union{Val,Bool} = Val (true ),
613+ normalize_states = Val (true ),
598614)
599- if typeof (ensemble_method) <: Union{EnsembleSplitThreads,EnsembleDistributed}
600- @sync begin
601- @async while take! (ens_prob_mc. kwargs. channel)
602- next! (ens_prob_mc. kwargs. progr)
603- end
604-
605- @async begin
606- sol = solve (ens_prob_mc. prob, alg, ensemble_method, trajectories = ntraj)
607- put! (ens_prob_mc. kwargs. channel, false )
608- end
609- end
610- else
611- sol = solve (ens_prob_mc. prob, alg, ensemble_method, trajectories = ntraj)
612- end
615+ sol = _mcsolve_solve_ens (ens_prob_mc, alg, ensemble_method, ntraj)
613616
614617 dims = ens_prob_mc. dims
615618 _sol_1 = sol[:, 1 ]
616619
617- expvals_all = Array {ComplexF64} (undef, length (sol), size (_sol_1. prob. p. expvals)... )
618- states =
619- isempty (_sol_1. prob. kwargs[:saveat ]) ?
620- fill (QuantumObject{Vector{ComplexF64},KetQuantumObject,length (dims)}[], length (sol)) :
621- Vector {Vector{QuantumObject}} (undef, length (sol))
622- jump_times = Vector {Vector{Float64}} (undef, length (sol))
623- jump_which = Vector {Vector{Int16}} (undef, length (sol))
624-
625- foreach (
626- i -> _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which, normalize_states, dims),
627- eachindex (sol),
628- )
629- expvals = dropdims (sum (expvals_all, dims = 1 ), dims = 1 ) ./ length (sol)
620+ expvals_all = mapreduce (i -> sol[:, i]. prob. p. expvals, (x, y) -> cat (x, y, dims = 3 ), eachindex (sol))
621+ states = map (i -> _normalize_state! .(sol[:, i]. u, Ref (dims), normalize_states), eachindex (sol))
622+ jump_times = map (i -> real .(sol[:, i]. prob. p. mcsolve_params. jump_times), eachindex (sol))
623+ jump_which = map (i -> round .(Int, sol[:, i]. prob. p. mcsolve_params. jump_which), eachindex (sol))
624+
625+ expvals = dropdims (sum (expvals_all, dims = 3 ), dims = 3 ) ./ length (sol)
630626
631627 return TimeEvolutionMCSol (
632628 ntraj,
0 commit comments