Skip to content

Commit 42f827e

Browse files
Fix type-instabilities for mcsolve
1 parent 3f74d4a commit 42f827e

File tree

3 files changed

+57
-61
lines changed

3 files changed

+57
-61
lines changed

src/time_evolution/mcsolve.jl

Lines changed: 55 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,7 @@ end
141141

142142
function _normalize_state!(u, dims, normalize_states)
143143
getVal(normalize_states) && normalize!(u)
144-
return QuantumObject(u, dims = dims)
145-
end
146-
147-
function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states, dims)
148-
sol_i = sol[:, i]
149-
!isempty(sol_i.prob.kwargs[:saveat]) ? states[i] = map(u -> _normalize_state!(u, dims, normalize_states), sol_i.u) :
150-
nothing
151-
152-
copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals)
153-
jump_times[i] = sol_i.prob.p.mcsolve_params.jump_times
154-
return jump_which[i] = sol_i.prob.p.mcsolve_params.jump_which
144+
return QuantumObject(u, type = Ket, dims = dims)
155145
end
156146

157147
function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
@@ -308,15 +298,15 @@ function mcsolveProblem(
308298
kwargs2 = merge(default_values, kwargs)
309299
kwargs3 = _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs2)
310300

311-
cache_mc = similar(ψ0.data)
312-
weights_mc = similar(ψ0.data, length(c_ops)) # It should be a Float64 Vector, but we have to keep the same type for all the parameters due to SciMLStructures.jl
301+
cache_mc = similar(ψ0.data, T)
302+
weights_mc = similar(ψ0.data, T, length(c_ops)) # It should be a Float64 Vector, but we have to keep the same type for all the parameters due to SciMLStructures.jl
313303
cumsum_weights_mc = similar(weights_mc)
314304

315-
jump_times = similar(ψ0.data, jump_times_which_init_size)
316-
jump_which = similar(ψ0.data, jump_times_which_init_size)
305+
jump_times = similar(ψ0.data, T, jump_times_which_init_size)
306+
jump_which = similar(ψ0.data, T, jump_times_which_init_size)
317307
jump_times_which_idx = T[1] # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl
318308

319-
random_n = similar(ψ0.data, 1) # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl.
309+
random_n = similar(ψ0.data, T, 1) # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl.
320310
random_n[1] = rand(rng)
321311

322312
progr = ProgressBar(length(tlist), enable = false)
@@ -432,24 +422,24 @@ function mcsolveEnsembleProblem(
432422
output_func::Union{Tuple,Nothing} = nothing,
433423
kwargs...,
434424
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
425+
_prob_func = prob_func isa Nothing ? _mcsolve_dispatch_prob_func(rng, ntraj) : prob_func
426+
_output_func =
427+
output_func isa Nothing ? _mcsolve_dispatch_output_func(ensemble_method, progress_bar, ntraj) : output_func
428+
435429
prob_mc = mcsolveProblem(
436430
H,
437431
ψ0,
438432
tlist,
439433
c_ops;
440434
e_ops = e_ops,
441435
params = params,
442-
rng = deepcopy(rng), # By deepcopying, we avoid to also count the initialization of mcsolveProblem
436+
rng = rng,
443437
jump_callback = jump_callback,
444438
kwargs...,
445439
)
446440

447-
_prob_func = prob_func isa Nothing ? _mcsolve_dispatch_prob_func(rng, ntraj) : prob_func
448-
_output_func =
449-
output_func isa Nothing ? _mcsolve_dispatch_output_func(ensemble_method, progress_bar, ntraj) : output_func
450-
451441
ensemble_prob = TimeEvolutionProblem(
452-
EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
442+
EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true),
453443
prob_mc.times,
454444
prob_mc.dims,
455445
(progr = _output_func[2], channel = _output_func[3]),
@@ -558,7 +548,7 @@ function mcsolve(
558548
jump_callback::TJC = ContinuousLindbladJumpCallback(),
559549
progress_bar::Union{Val,Bool} = Val(true),
560550
prob_func::Union{Function,Nothing} = nothing,
561-
output_func::Union{Function,Nothing} = nothing,
551+
output_func::Union{Tuple,Nothing} = nothing,
562552
normalize_states::Union{Val,Bool} = Val(true),
563553
kwargs...,
564554
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
@@ -580,53 +570,59 @@ function mcsolve(
580570
kwargs...,
581571
)
582572

583-
return mcsolve(
584-
ens_prob_mc;
585-
alg = alg,
586-
ntraj = ntraj,
587-
ensemble_method = ensemble_method,
588-
normalize_states = normalize_states,
589-
)
573+
return mcsolve(ens_prob_mc, alg, ntraj, ensemble_method, normalize_states)
574+
end
575+
576+
function _mcsolve_solve_ens(
577+
ens_prob_mc::TimeEvolutionProblem,
578+
alg::OrdinaryDiffEqAlgorithm,
579+
ensemble_method::ET,
580+
ntraj::Int,
581+
) where {ET<:Union{EnsembleSplitThreads,EnsembleDistributed}}
582+
sol = nothing
583+
584+
@sync begin
585+
@async while take!(ens_prob_mc.kwargs.channel)
586+
next!(ens_prob_mc.kwargs.progr)
587+
end
588+
589+
@async begin
590+
sol = solve(ens_prob_mc.prob, alg, ensemble_method, trajectories = ntraj)
591+
put!(ens_prob_mc.kwargs.channel, false)
592+
end
593+
end
594+
595+
return sol
596+
end
597+
598+
function _mcsolve_solve_ens(
599+
ens_prob_mc::TimeEvolutionProblem,
600+
alg::OrdinaryDiffEqAlgorithm,
601+
ensemble_method,
602+
ntraj::Int,
603+
)
604+
sol = solve(ens_prob_mc.prob, alg, ensemble_method, trajectories = ntraj)
605+
return sol
590606
end
591607

592608
function mcsolve(
593-
ens_prob_mc::TimeEvolutionProblem;
609+
ens_prob_mc::TimeEvolutionProblem,
594610
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
595611
ntraj::Int = 1,
596612
ensemble_method = EnsembleThreads(),
597-
normalize_states::Union{Val,Bool} = Val(true),
613+
normalize_states = Val(true),
598614
)
599-
if typeof(ensemble_method) <: Union{EnsembleSplitThreads,EnsembleDistributed}
600-
@sync begin
601-
@async while take!(ens_prob_mc.kwargs.channel)
602-
next!(ens_prob_mc.kwargs.progr)
603-
end
604-
605-
@async begin
606-
sol = solve(ens_prob_mc.prob, alg, ensemble_method, trajectories = ntraj)
607-
put!(ens_prob_mc.kwargs.channel, false)
608-
end
609-
end
610-
else
611-
sol = solve(ens_prob_mc.prob, alg, ensemble_method, trajectories = ntraj)
612-
end
615+
sol = _mcsolve_solve_ens(ens_prob_mc, alg, ensemble_method, ntraj)
613616

614617
dims = ens_prob_mc.dims
615618
_sol_1 = sol[:, 1]
616619

617-
expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
618-
states =
619-
isempty(_sol_1.prob.kwargs[:saveat]) ?
620-
fill(QuantumObject{Vector{ComplexF64},KetQuantumObject,length(dims)}[], length(sol)) :
621-
Vector{Vector{QuantumObject}}(undef, length(sol))
622-
jump_times = Vector{Vector{Float64}}(undef, length(sol))
623-
jump_which = Vector{Vector{Int16}}(undef, length(sol))
624-
625-
foreach(
626-
i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states, dims),
627-
eachindex(sol),
628-
)
629-
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)
620+
expvals_all = mapreduce(i -> sol[:, i].prob.p.expvals, (x, y) -> cat(x, y, dims = 3), eachindex(sol))
621+
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
622+
jump_times = map(i -> real.(sol[:, i].prob.p.mcsolve_params.jump_times), eachindex(sol))
623+
jump_which = map(i -> round.(Int, sol[:, i].prob.p.mcsolve_params.jump_which), eachindex(sol))
624+
625+
expvals = dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
630626

631627
return TimeEvolutionMCSol(
632628
ntraj,

src/time_evolution/time_evolution_dynamical.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,5 +731,5 @@ function dsf_mcsolve(
731731
kwargs...,
732732
)
733733

734-
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
734+
return mcsolve(ens_prob_mc, alg, ntraj, ensemble_method)
735735
end

test/core-test/time_evolution.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@
393393
@test sol_mc1.jump_times sol_mc2.jump_times atol = 1e-10
394394
@test sol_mc1.jump_which sol_mc2.jump_which atol = 1e-10
395395

396-
@test sol_mc1.expect_all sol_mc3.expect_all[1:500, :, :] atol = 1e-10
396+
@test sol_mc1.expect_all sol_mc3.expect_all[:, :, 1:500] atol = 1e-10
397397

398398
@test sol_sse1.expect sol_sse2.expect atol = 1e-10
399399
@test sol_sse1.expect_all sol_sse2.expect_all atol = 1e-10

0 commit comments

Comments
 (0)