Skip to content

Commit 9b23c6f

Browse files
Add normalize_states option in mcsolve (#285)
1 parent 2adcd8a commit 9b23c6f

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

src/time_evolution/mcsolve.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,16 @@ _mcsolve_dispatch_output_func() = _mcsolve_output_func
109109
_mcsolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _mcsolve_output_func_progress
110110
_mcsolve_dispatch_output_func(::EnsembleDistributed) = _mcsolve_output_func_distributed
111111

112-
function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which)
112+
function _normalize_state!(u, dims, normalize_states)
113+
getVal(normalize_states) && normalize!(u)
114+
return QuantumObject(u, dims = dims)
115+
end
116+
117+
function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states)
113118
sol_i = sol[:, i]
114-
!isempty(sol_i.prob.kwargs[:saveat]) ?
115-
states[i] = [QuantumObject(normalize!(sol_i.u[i]), dims = sol_i.prob.p.Hdims) for i in 1:length(sol_i.u)] : nothing
119+
dims = sol_i.prob.p.Hdims
120+
!isempty(sol_i.prob.kwargs[:saveat]) ? states[i] = map(u -> _normalize_state!(u, dims, normalize_states), sol_i.u) :
121+
nothing
116122

117123
copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals)
118124
jump_times[i] = sol_i.prob.p.jump_times
@@ -461,6 +467,7 @@ end
461467
prob_func::Function = _mcsolve_prob_func,
462468
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
463469
progress_bar::Union{Val,Bool} = Val(true),
470+
normalize_states::Union{Val,Bool} = Val(true),
464471
kwargs...,
465472
)
466473
@@ -514,6 +521,7 @@ If the environmental measurements register a quantum jump, the wave function und
514521
- `prob_func`: Function to use for generating the ODEProblem.
515522
- `output_func`: Function to use for generating the output of a single trajectory.
516523
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
524+
- `normalize_states`: Whether to normalize the states. Default to `Val(true)`.
517525
- `kwargs`: The keyword arguments for the ODEProblem.
518526
519527
# Notes
@@ -544,6 +552,7 @@ function mcsolve(
544552
prob_func::Function = _mcsolve_prob_func,
545553
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
546554
progress_bar::Union{Val,Bool} = Val(true),
555+
normalize_states::Union{Val,Bool} = Val(true),
547556
kwargs...,
548557
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
549558
ens_prob_mc = mcsolveEnsembleProblem(
@@ -564,14 +573,21 @@ function mcsolve(
564573
kwargs...,
565574
)
566575

567-
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
576+
return mcsolve(
577+
ens_prob_mc;
578+
alg = alg,
579+
ntraj = ntraj,
580+
ensemble_method = ensemble_method,
581+
normalize_states = normalize_states,
582+
)
568583
end
569584

570585
function mcsolve(
571586
ens_prob_mc::EnsembleProblem;
572587
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
573588
ntraj::Int = 1,
574589
ensemble_method = EnsembleThreads(),
590+
normalize_states::Union{Val,Bool} = Val(true),
575591
)
576592
try
577593
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
@@ -589,7 +605,10 @@ function mcsolve(
589605
jump_times = Vector{Vector{Float64}}(undef, length(sol))
590606
jump_which = Vector{Vector{Int16}}(undef, length(sol))
591607

592-
foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol))
608+
foreach(
609+
i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states),
610+
eachindex(sol),
611+
)
593612
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)
594613

595614
return TimeEvolutionMCSol(

0 commit comments

Comments
 (0)