@@ -109,10 +109,16 @@ _mcsolve_dispatch_output_func() = _mcsolve_output_func
109109_mcsolve_dispatch_output_func (:: ET ) where {ET<: Union{EnsembleSerial,EnsembleThreads} } = _mcsolve_output_func_progress
110110_mcsolve_dispatch_output_func (:: EnsembleDistributed ) = _mcsolve_output_func_distributed
111111
112- function _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which)
112+ function _normalize_state! (u, dims, normalize_states)
113+ getVal (normalize_states) && normalize! (u)
114+ return QuantumObject (u, dims = dims)
115+ end
116+
117+ function _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which, normalize_states)
113118 sol_i = sol[:, i]
114- ! isempty (sol_i. prob. kwargs[:saveat ]) ?
115- states[i] = [QuantumObject (normalize! (sol_i. u[i]), dims = sol_i. prob. p. Hdims) for i in 1 : length (sol_i. u)] : nothing
119+ dims = sol_i. prob. p. Hdims
120+ ! isempty (sol_i. prob. kwargs[:saveat ]) ? states[i] = map (u -> _normalize_state! (u, dims, normalize_states), sol_i. u) :
121+ nothing
116122
117123 copyto! (view (expvals_all, i, :, :), sol_i. prob. p. expvals)
118124 jump_times[i] = sol_i. prob. p. jump_times
461467 prob_func::Function = _mcsolve_prob_func,
462468 output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
463469 progress_bar::Union{Val,Bool} = Val(true),
470+ normalize_states::Union{Val,Bool} = Val(true),
464471 kwargs...,
465472 )
466473
@@ -514,6 +521,7 @@ If the environmental measurements register a quantum jump, the wave function und
514521- `prob_func`: Function to use for generating the ODEProblem.
515522- `output_func`: Function to use for generating the output of a single trajectory.
516523- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
524+ - `normalize_states`: Whether to normalize the states. Default to `Val(true)`.
517525- `kwargs`: The keyword arguments for the ODEProblem.
518526
519527# Notes
@@ -544,6 +552,7 @@ function mcsolve(
544552 prob_func:: Function = _mcsolve_prob_func,
545553 output_func:: Function = _mcsolve_dispatch_output_func (ensemble_method),
546554 progress_bar:: Union{Val,Bool} = Val (true ),
555+ normalize_states:: Union{Val,Bool} = Val (true ),
547556 kwargs... ,
548557) where {DT1,DT2,TJC<: LindbladJumpCallbackType }
549558 ens_prob_mc = mcsolveEnsembleProblem (
@@ -564,14 +573,21 @@ function mcsolve(
564573 kwargs... ,
565574 )
566575
567- return mcsolve (ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
576+ return mcsolve (
577+ ens_prob_mc;
578+ alg = alg,
579+ ntraj = ntraj,
580+ ensemble_method = ensemble_method,
581+ normalize_states = normalize_states,
582+ )
568583end
569584
570585function mcsolve (
571586 ens_prob_mc:: EnsembleProblem ;
572587 alg:: OrdinaryDiffEqAlgorithm = Tsit5 (),
573588 ntraj:: Int = 1 ,
574589 ensemble_method = EnsembleThreads (),
590+ normalize_states:: Union{Val,Bool} = Val (true ),
575591)
576592 try
577593 sol = solve (ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
@@ -589,7 +605,10 @@ function mcsolve(
589605 jump_times = Vector {Vector{Float64}} (undef, length (sol))
590606 jump_which = Vector {Vector{Int16}} (undef, length (sol))
591607
592- foreach (i -> _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which), eachindex (sol))
608+ foreach (
609+ i -> _mcsolve_generate_statistics (sol, i, states, expvals_all, jump_times, jump_which, normalize_states),
610+ eachindex (sol),
611+ )
593612 expvals = dropdims (sum (expvals_all, dims = 1 ), dims = 1 ) ./ length (sol)
594613
595614 return TimeEvolutionMCSol (
0 commit comments