Skip to content

Commit 84b2c64

Browse files
Working mesolve
1 parent 26f6f4c commit 84b2c64

File tree

10 files changed

+172
-179
lines changed

10 files changed

+172
-179
lines changed

src/QuantumToolbox.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ include("qobj/synonyms.jl")
9595
# time evolution
9696
include("time_evolution/time_evo_parameters.jl")
9797
include("time_evolution/time_evolution.jl")
98-
include("time_evolution/callback_helpers.jl")
98+
include("time_evolution/callback_helpers/callback_helpers.jl")
9999
include("time_evolution/mesolve.jl")
100100
include("time_evolution/lr_mesolve.jl")
101101
include("time_evolution/sesolve.jl")

src/qobj/eigsolve.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,15 @@ function eigsolve_al(
391391
kwargs...,
392392
) where {DT1,HOpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject}}
393393
L_evo = _mesolve_make_L_QobjEvo(H, c_ops)
394-
prob = mesolveProblem(
395-
L_evo,
396-
QuantumObject(ρ0, type = Operator, dims = H.dims),
397-
[zero(T), T];
398-
params = params,
399-
progress_bar = Val(false),
400-
kwargs...,
401-
)
394+
prob =
395+
mesolveProblem(
396+
L_evo,
397+
QuantumObject(ρ0, type = Operator, dims = H.dims),
398+
[zero(T), T];
399+
params = params,
400+
progress_bar = Val(false),
401+
kwargs...,
402+
).prob
402403
integrator = init(prob, alg)
403404

404405
# prog = ProgressUnknown(desc="Applications:", showspeed = true, enabled=progress)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#=
2+
This file contains helper functions for callbacks. The affect! function are defined taking advantage of the Julia struct, which allows to store some cache exclusively for the callback.
3+
=#
4+
5+
include("sesolve_callback_helpers.jl")
6+
include("mesolve_callback_helpers.jl")
7+
include("mcsolve_callback_helpers.jl")
8+
9+
##
10+
11+
# Multiple dispatch depending on the progress_bar and e_ops types
12+
function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method)
13+
cb = _generate_save_callback(e_ops, tlist, progress_bar, method)
14+
return _merge_kwargs_with_callback(kwargs, cb)
15+
end
16+
_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs
17+
18+
function _merge_kwargs_with_callback(kwargs, cb)
19+
kwargs2 =
20+
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) :
21+
merge(kwargs, (callback = cb,))
22+
23+
return kwargs2
24+
end
25+
26+
function _generate_save_callback(e_ops, tlist, progress_bar, method)
27+
e_ops_data = e_ops isa Nothing ? nothing : get_data.(e_ops)
28+
29+
progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing
30+
31+
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
32+
33+
_save_affect! = method(e_ops_data, progr, Ref(1), expvals)
34+
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
35+
end
36+
37+
# When e_ops is Nothing. Common for both mesolve and sesolve
38+
function _save_func(integrator, progr)
39+
next!(progr)
40+
u_modified!(integrator, false)
41+
return nothing
42+
end
43+
44+
# When progr is Nothing. Common for both mesolve and sesolve
45+
function _save_func(integrator, progr::Nothing)
46+
u_modified!(integrator, false)
47+
return nothing
48+
end
49+
50+
##
51+
52+
# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`.
53+
function _se_me_sse_get_expvals(sol::AbstractODESolution)
54+
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
55+
if hasproperty(kwargs, :callback)
56+
return _se_me_sse_get_expvals(kwargs.callback)
57+
else
58+
return nothing
59+
end
60+
end
61+
function _se_me_sse_get_expvals(cb::CallbackSet)
62+
_cb = cb.discrete_callbacks[1]
63+
return _se_me_sse_get_expvals(_cb)
64+
end
65+
_se_me_sse_get_expvals(cb::DiscreteCallback) =
66+
if (cb.affect! isa SaveFuncSESolve) || (cb.affect! isa SaveFuncMESolve)
67+
return cb.affect!.expvals
68+
else
69+
return nothing
70+
end
71+
_se_me_sse_get_expvals(cb::ContinuousCallback) = nothing

src/time_evolution/callback_helpers.jl renamed to src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 1 addition & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,7 @@
11
#=
2-
This file contains helper functions for callbacks. The affect! function are defined taking advantage of the Julia struct, which allows to store some cache exclusively for the callback.
2+
Helper functions for the mcsolve callbacks.
33
=#
44

5-
########## SESOLVE ##########
6-
7-
struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
8-
e_ops::TE
9-
progr::PT
10-
iter::IT
11-
expvals::TEXPV
12-
end
13-
14-
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
15-
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func_sesolve(integrator, f.progr)
16-
17-
##
18-
19-
# When e_ops is Nothing
20-
function _save_func_sesolve(integrator, progr)
21-
next!(progr)
22-
u_modified!(integrator, false)
23-
return nothing
24-
end
25-
26-
# When progr is Nothing
27-
function _save_func_sesolve(integrator, progr::Nothing)
28-
u_modified!(integrator, false)
29-
return nothing
30-
end
31-
32-
# When e_ops is a list of operators
33-
function _save_func_sesolve(integrator, e_ops, progr, iter, expvals)
34-
ψ = integrator.u
35-
_expect = op -> dot(ψ, op, ψ)
36-
@. expvals[:, iter[]] = _expect(e_ops)
37-
iter[] += 1
38-
39-
return _save_func_sesolve(integrator, progr)
40-
end
41-
42-
function _generate_sesolve_callback(e_ops, tlist, progress_bar)
43-
e_ops_data = e_ops isa Nothing ? nothing : get_data.(e_ops)
44-
45-
progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing
46-
47-
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
48-
49-
_save_affect! = SaveFuncSESolve(e_ops_data, progr, Ref(1), expvals)
50-
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
51-
end
52-
53-
function _sesolve_get_expvals(sol::AbstractODESolution)
54-
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
55-
if hasproperty(kwargs, :callback)
56-
return _sesolve_get_expvals(kwargs.callback)
57-
else
58-
return nothing
59-
end
60-
end
61-
function _sesolve_get_expvals(cb::CallbackSet)
62-
_cb = cb.discrete_callbacks[1]
63-
return _sesolve_get_expvals(_cb)
64-
end
65-
_sesolve_get_expvals(cb::DiscreteCallback) =
66-
if cb.affect! isa SaveFuncSESolve
67-
return cb.affect!.expvals
68-
else
69-
return nothing
70-
end
71-
_sesolve_get_expvals(cb::ContinuousCallback) = nothing
72-
73-
########## MCSOLVE ##########
74-
755
struct SaveFuncMCSolve{TE,IT,TEXPV}
766
e_ops::TE
777
iter::IT
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#=
2+
Helper functions for the mesolve callbacks.
3+
=#
4+
5+
struct SaveFuncMESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
6+
e_ops::TE
7+
progr::PT
8+
iter::IT
9+
expvals::TEXPV
10+
end
11+
12+
(f::SaveFuncMESolve)(integrator) = _save_func_mesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
13+
(f::SaveFuncMESolve{Nothing})(integrator) = _save_func(integrator, f.progr)
14+
15+
##
16+
17+
# When e_ops is a list of operators
18+
function _save_func_mesolve(integrator, e_ops, progr, iter, expvals)
19+
# This is equivalent to tr(op * ρ), when both are matrices.
20+
# The advantage of using this convention is that We don't need
21+
# to reshape u to make it a matrix, but we reshape the e_ops once.
22+
23+
ρ = integrator.u
24+
_expect = op -> dot(op, ρ)
25+
@. expvals[:, iter[]] = _expect(e_ops)
26+
iter[] += 1
27+
28+
return _save_func(integrator, progr)
29+
end
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#=
2+
Helper functions for the sesolve callbacks.
3+
=#
4+
5+
struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
6+
e_ops::TE
7+
progr::PT
8+
iter::IT
9+
expvals::TEXPV
10+
end
11+
12+
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
13+
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve
14+
15+
##
16+
17+
# When e_ops is a list of operators
18+
function _save_func_sesolve(integrator, e_ops, progr, iter, expvals)
19+
ψ = integrator.u
20+
_expect = op -> dot(ψ, op, ψ)
21+
@. expvals[:, iter[]] = _expect(e_ops)
22+
iter[] += 1
23+
24+
return _save_func(integrator, progr)
25+
end

0 commit comments

Comments
 (0)