Skip to content

Commit 18549c1

Browse files
Add callbacks helpers
1 parent 5579426 commit 18549c1

File tree

5 files changed

+153
-126
lines changed

5 files changed

+153
-126
lines changed

src/QuantumToolbox.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ include("qobj/synonyms.jl")
9292
# time evolution
9393
include("time_evolution/time_evo_parameters.jl")
9494
include("time_evolution/time_evolution.jl")
95+
include("time_evolution/callback_helpers.jl")
9596
include("time_evolution/mesolve.jl")
9697
include("time_evolution/lr_mesolve.jl")
9798
include("time_evolution/sesolve.jl")
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#=
2+
This file contains helper functions for callbacks. The affect! function are defined taking advantage of the Julia struct, which allows to store some cache exclusively for the callback.
3+
=#
4+
5+
########## SESOLVE ##########
6+
7+
struct SaveFuncSESolve{T1,T2}
8+
e_ops::T1
9+
is_empty_e_ops::T2
10+
end
11+
12+
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.is_empty_e_ops)
13+
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func_sesolve(integrator)
14+
15+
# When e_ops is Nothing
16+
function _save_func_sesolve(integrator)
17+
next!(integrator.p.progr)
18+
u_modified!(integrator, false)
19+
return nothing
20+
end
21+
22+
# When e_ops is a list of operators
23+
function _save_func_sesolve(integrator, e_ops, is_empty_e_ops)
24+
expvals = integrator.p.expvals
25+
progr = integrator.p.progr
26+
if !is_empty_e_ops
27+
ψ = integrator.u
28+
_expect = op -> dot(ψ, get_data(op), ψ)
29+
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
30+
end
31+
return _save_func_sesolve(integrator)
32+
end
33+
34+
function _generate_sesolve_callback(e_ops, tlist)
35+
is_empty_e_ops = e_ops isa Nothing ? true : isempty(e_ops)
36+
_save_affect! = SaveFuncSESolve(e_ops, is_empty_e_ops)
37+
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
38+
end
39+
40+
########## MCSOLVE ##########
41+
42+
struct SaveFuncMCSolve{T1,T2}
43+
e_ops::T1
44+
is_empty_e_ops::T2
45+
end
46+
47+
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.is_empty_e_ops)
48+
49+
struct LindbladJumpAffect!{T1,T2}
50+
c_ops::T1
51+
c_ops_herm::T2
52+
end
53+
54+
(f::LindbladJumpAffect!)(integrator) = _lindblad_jump_affect!(integrator, f.c_ops, f.c_ops_herm)
55+
56+
##
57+
58+
function _save_func_mcsolve(integrator, e_ops, is_empty_e_ops)
59+
expvals = integrator.p.expvals
60+
progr = integrator.p.progr
61+
cache_mc = integrator.p.mcsolve_params.cache_mc
62+
if !is_empty_e_ops
63+
copyto!(cache_mc, integrator.u)
64+
normalize!(cache_mc)
65+
ψ = cache_mc
66+
_expect = op -> dot(ψ, get_data(op), ψ)
67+
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
68+
end
69+
next!(progr)
70+
u_modified!(integrator, false)
71+
return nothing
72+
end
73+
74+
function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
75+
c_ops_data = get_data.(c_ops)
76+
c_ops_herm_data = map(op -> op' * op, c_ops_data)
77+
78+
_affect! = LindbladJumpAffect!(c_ops_data, c_ops_herm_data)
79+
80+
if jump_callback isa DiscreteLindbladJumpCallback
81+
cb1 = DiscreteCallback(_mcsolve_discrete_condition, _affect!, save_positions = (false, false))
82+
else
83+
cb1 = ContinuousCallback(
84+
_mcsolve_continuous_condition,
85+
_affect!,
86+
nothing,
87+
interp_points = jump_callback.interp_points,
88+
save_positions = (false, false),
89+
)
90+
end
91+
92+
if e_ops isa Nothing
93+
kwargs2 =
94+
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, kwargs.callback),)) :
95+
merge(kwargs, (callback = cb1,))
96+
return kwargs2
97+
else
98+
is_empty_e_ops = isempty(e_ops)
99+
_save_affect! = SaveFuncMCSolve(e_ops, is_empty_e_ops)
100+
cb2 = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
101+
kwargs2 =
102+
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
103+
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
104+
return kwargs2
105+
end
106+
end
107+
108+
function _lindblad_jump_affect!(integrator, c_ops, c_ops_herm)
109+
params = integrator.p
110+
cache_mc = params.mcsolve_params.cache_mc
111+
weights_mc = params.mcsolve_params.weights_mc
112+
cumsum_weights_mc = params.mcsolve_params.cumsum_weights_mc
113+
random_n = params.mcsolve_params.random_n
114+
jump_times = params.mcsolve_params.jump_times
115+
jump_which = params.mcsolve_params.jump_which
116+
jump_times_which_idx = params.mcsolve_params.jump_times_which_idx
117+
traj_rng = params.mcsolve_params.traj_rng
118+
ψ = integrator.u
119+
120+
@inbounds for i in eachindex(weights_mc)
121+
weights_mc[i] = real(dot(ψ, c_ops_herm[i], ψ))
122+
end
123+
cumsum!(cumsum_weights_mc, weights_mc)
124+
r = rand(traj_rng) * sum(real, weights_mc)
125+
collapse_idx = getindex(1:length(weights_mc), findfirst(x -> real(x) > r, cumsum_weights_mc))
126+
mul!(cache_mc, c_ops[collapse_idx], ψ)
127+
normalize!(cache_mc)
128+
copyto!(integrator.u, cache_mc)
129+
130+
@inbounds random_n[1] = rand(traj_rng)
131+
132+
@inbounds idx = round(Int, real(jump_times_which_idx[1]))
133+
@inbounds jump_times[idx] = integrator.t
134+
@inbounds jump_which[idx] = collapse_idx
135+
@inbounds jump_times_which_idx[1] += 1
136+
@inbounds if real(jump_times_which_idx[1]) > length(jump_times)
137+
resize!(jump_times, length(jump_times) + JUMP_TIMES_WHICH_INIT_SIZE)
138+
resize!(jump_which, length(jump_which) + JUMP_TIMES_WHICH_INIT_SIZE)
139+
end
140+
u_modified!(integrator, true)
141+
return nothing
142+
end
143+
144+
_mcsolve_continuous_condition(u, t, integrator) =
145+
@inbounds real(integrator.p.mcsolve_params.random_n[1]) - real(dot(u, u))
146+
147+
_mcsolve_discrete_condition(u, t, integrator) =
148+
@inbounds real(dot(u, u)) < real(integrator.p.mcsolve_params.random_n[1])

src/time_evolution/mcsolve.jl

Lines changed: 3 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,6 @@
11
export mcsolveProblem, mcsolveEnsembleProblem, mcsolve
22
export ContinuousLindbladJumpCallback, DiscreteLindbladJumpCallback
33

4-
const jump_times_which_init_size = 200
5-
6-
function _save_func_mcsolve(integrator, e_ops, is_empty_e_ops)
7-
expvals = integrator.p.expvals
8-
progr = integrator.p.progr
9-
cache_mc = integrator.p.mcsolve_params.cache_mc
10-
if !is_empty_e_ops
11-
copyto!(cache_mc, integrator.u)
12-
normalize!(cache_mc)
13-
ψ = cache_mc
14-
_expect = op -> dot(ψ, get_data(op), ψ)
15-
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
16-
end
17-
next!(progr)
18-
u_modified!(integrator, false)
19-
return nothing
20-
end
21-
22-
function LindbladJumpAffect!(integrator, c_ops, c_ops_herm)
23-
params = integrator.p
24-
cache_mc = params.mcsolve_params.cache_mc
25-
weights_mc = params.mcsolve_params.weights_mc
26-
cumsum_weights_mc = params.mcsolve_params.cumsum_weights_mc
27-
random_n = params.mcsolve_params.random_n
28-
jump_times = params.mcsolve_params.jump_times
29-
jump_which = params.mcsolve_params.jump_which
30-
jump_times_which_idx = params.mcsolve_params.jump_times_which_idx
31-
traj_rng = params.mcsolve_params.traj_rng
32-
ψ = integrator.u
33-
34-
@inbounds for i in eachindex(weights_mc)
35-
weights_mc[i] = real(dot(ψ, c_ops_herm[i], ψ))
36-
end
37-
cumsum!(cumsum_weights_mc, weights_mc)
38-
r = rand(traj_rng) * sum(real, weights_mc)
39-
collapse_idx = getindex(1:length(weights_mc), findfirst(x -> real(x) > r, cumsum_weights_mc))
40-
mul!(cache_mc, c_ops[collapse_idx], ψ)
41-
normalize!(cache_mc)
42-
copyto!(integrator.u, cache_mc)
43-
44-
@inbounds random_n[1] = rand(traj_rng)
45-
46-
@inbounds idx = round(Int, real(jump_times_which_idx[1]))
47-
@inbounds jump_times[idx] = integrator.t
48-
@inbounds jump_which[idx] = collapse_idx
49-
@inbounds jump_times_which_idx[1] += 1
50-
@inbounds if real(jump_times_which_idx[1]) > length(jump_times)
51-
resize!(jump_times, length(jump_times) + jump_times_which_init_size)
52-
resize!(jump_which, length(jump_which) + jump_times_which_init_size)
53-
end
54-
end
55-
56-
_mcsolve_continuous_condition(u, t, integrator) =
57-
@inbounds real(integrator.p.mcsolve_params.random_n[1]) - real(dot(u, u))
58-
59-
_mcsolve_discrete_condition(u, t, integrator) =
60-
@inbounds real(dot(u, u)) < real(integrator.p.mcsolve_params.random_n[1])
61-
624
function _mcsolve_prob_func(prob, i, repeat, global_rng, seeds)
635
params = prob.p
646

@@ -144,40 +86,6 @@ function _normalize_state!(u, dims, normalize_states)
14486
return QuantumObject(u, type = Ket, dims = dims)
14587
end
14688

147-
function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
148-
c_ops_data = get_data.(c_ops)
149-
c_ops_herm_data = map(op -> op' * op, c_ops_data)
150-
151-
_affect = integrator -> LindbladJumpAffect!(integrator, c_ops_data, c_ops_herm_data)
152-
153-
if jump_callback isa DiscreteLindbladJumpCallback
154-
cb1 = DiscreteCallback(_mcsolve_discrete_condition, _affect, save_positions = (false, false))
155-
else
156-
cb1 = ContinuousCallback(
157-
_mcsolve_continuous_condition,
158-
_affect,
159-
nothing,
160-
interp_points = jump_callback.interp_points,
161-
save_positions = (false, false),
162-
)
163-
end
164-
165-
if e_ops isa Nothing
166-
kwargs2 =
167-
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, kwargs.callback),)) :
168-
merge(kwargs, (callback = cb1,))
169-
return kwargs2
170-
else
171-
is_empty_e_ops = isempty(e_ops)
172-
f = integrator -> _save_func_mcsolve(integrator, e_ops, is_empty_e_ops)
173-
cb2 = PresetTimeCallback(tlist, f, save_positions = (false, false))
174-
kwargs2 =
175-
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
176-
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
177-
return kwargs2
178-
end
179-
end
180-
18189
function _mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops)
18290
c_ops isa Nothing && return QobjEvo(H)
18391
return QobjEvo(H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2)
@@ -302,8 +210,8 @@ function mcsolveProblem(
302210
weights_mc = similar(ψ0.data, T, length(c_ops)) # It should be a Float64 Vector, but we have to keep the same type for all the parameters due to SciMLStructures.jl
303211
cumsum_weights_mc = similar(weights_mc)
304212

305-
jump_times = similar(ψ0.data, T, jump_times_which_init_size)
306-
jump_which = similar(ψ0.data, T, jump_times_which_init_size)
213+
jump_times = similar(ψ0.data, T, JUMP_TIMES_WHICH_INIT_SIZE)
214+
jump_which = similar(ψ0.data, T, JUMP_TIMES_WHICH_INIT_SIZE)
307215
jump_times_which_idx = T[1] # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl
308216

309217
random_n = similar(ψ0.data, T, 1) # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl.
@@ -439,7 +347,7 @@ function mcsolveEnsembleProblem(
439347
)
440348

441349
ensemble_prob = TimeEvolutionProblem(
442-
EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true),
350+
EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
443351
prob_mc.times,
444352
prob_mc.dims,
445353
(progr = _output_func[2], channel = _output_func[3]),

src/time_evolution/sesolve.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,5 @@
11
export sesolveProblem, sesolve
22

3-
# When e_ops is Nothing
4-
function _save_func_sesolve(integrator)
5-
next!(integrator.p.progr)
6-
u_modified!(integrator, false)
7-
return nothing
8-
end
9-
10-
# When e_ops is a list of operators
11-
function _save_func_sesolve(integrator, e_ops, is_empty_e_ops)
12-
expvals = integrator.p.expvals
13-
progr = integrator.p.progr
14-
if !is_empty_e_ops
15-
ψ = integrator.u
16-
_expect = op -> dot(ψ, get_data(op), ψ)
17-
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
18-
end
19-
return _save_func_sesolve(integrator)
20-
end
21-
22-
# Generate the callback depending on the e_ops type
23-
function _generate_sesolve_callback(e_ops::Nothing, tlist)
24-
f = integrator -> _save_func_sesolve(integrator)
25-
return PresetTimeCallback(tlist, f, save_positions = (false, false))
26-
end
27-
28-
function _generate_sesolve_callback(e_ops, tlist)
29-
is_empty_e_ops = isempty(e_ops)
30-
f = integrator -> _save_func_sesolve(integrator, e_ops, is_empty_e_ops)
31-
return PresetTimeCallback(tlist, f, save_positions = (false, false))
32-
end
33-
343
function _merge_sesolve_kwargs_with_callback(kwargs, cb)
354
kwargs2 =
365
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb),)) :

src/time_evolution/time_evolution.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ export liouvillian_floquet, liouvillian_generalized
44

55
const DEFAULT_ODE_SOLVER_OPTIONS = (abstol = 1e-8, reltol = 1e-6, save_everystep = false, save_end = true)
66
const DEFAULT_SDE_SOLVER_OPTIONS = (abstol = 1e-2, reltol = 1e-2, save_everystep = false, save_end = true)
7+
const JUMP_TIMES_WHICH_INIT_SIZE = 200
78

89
@doc raw"""
910
struct TimeEvolutionProblem

0 commit comments

Comments
 (0)