Skip to content

Commit 89b072c

Browse files
Remove TimeEvolutionParameters (type-unstable)
1 parent 81d627c commit 89b072c

File tree

7 files changed

+231
-230
lines changed

7 files changed

+231
-230
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2121
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2222
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2323
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
24-
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
2524
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2625
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2726
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
@@ -54,7 +53,6 @@ Random = "1"
5453
Reexport = "1"
5554
SciMLBase = "2"
5655
SciMLOperators = "0.3"
57-
SciMLStructures = "1.5.0"
5856
SparseArrays = "1"
5957
SpecialFunctions = "2"
6058
StaticArraysCore = "1"

docs/src/resources/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ qeye
181181
## [Time evolution](@id doc-API:Time-evolution)
182182

183183
```@docs
184-
TimeEvolutionParameters
185184
TimeEvolutionProblem
186185
TimeEvolutionSol
187186
TimeEvolutionMCSol

src/QuantumToolbox.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ include("qobj/superoperators.jl")
9393
include("qobj/synonyms.jl")
9494

9595
# time evolution
96-
include("time_evolution/time_evo_parameters.jl")
9796
include("time_evolution/time_evolution.jl")
9897
include("time_evolution/callback_helpers/callback_helpers.jl")
9998
include("time_evolution/mesolve.jl")

src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 212 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,47 @@ end
1010

1111
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals)
1212

13-
struct LindbladJump{T1,T2}
13+
struct LindbladJump{
14+
T1,
15+
T2,
16+
RNGType<:AbstractRNG,
17+
RandT,
18+
CT<:AbstractVector,
19+
WT<:AbstractVector,
20+
JTT<:AbstractVector,
21+
JWT<:AbstractVector,
22+
JTWIT,
23+
}
1424
c_ops::T1
1525
c_ops_herm::T2
26+
traj_rng::RNGType
27+
random_n::RandT
28+
cache_mc::CT
29+
weights_mc::WT
30+
cumsum_weights_mc::WT
31+
jump_times::JTT
32+
jump_which::JWT
33+
jump_times_which_idx::JTWIT
1634
end
1735

18-
(f::LindbladJump)(integrator) = _lindblad_jump_affect!(integrator, f.c_ops, f.c_ops_herm)
36+
(f::LindbladJump)(integrator) = _lindblad_jump_affect!(
37+
integrator,
38+
f.c_ops,
39+
f.c_ops_herm,
40+
f.traj_rng,
41+
f.random_n,
42+
f.cache_mc,
43+
f.weights_mc,
44+
f.cumsum_weights_mc,
45+
f.jump_times,
46+
f.jump_which,
47+
f.jump_times_which_idx,
48+
)
1949

2050
##
2151

2252
function _save_func_mcsolve(integrator, e_ops, iter, expvals)
23-
cache_mc = integrator.p.mcsolve_params.cache_mc
53+
cache_mc = _mc_get_jump_callback(integrator).affect!.cache_mc
2454

2555
copyto!(cache_mc, integrator.u)
2656
normalize!(cache_mc)
@@ -33,11 +63,32 @@ function _save_func_mcsolve(integrator, e_ops, iter, expvals)
3363
return nothing
3464
end
3565

36-
function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
66+
function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rng, kwargs)
3767
c_ops_data = get_data.(c_ops)
3868
c_ops_herm_data = map(op -> op' * op, c_ops_data)
3969

40-
_affect! = LindbladJump(c_ops_data, c_ops_herm_data)
70+
cache_mc = similar(ψ0.data, T)
71+
weights_mc = Vector{Float64}(undef, length(c_ops))
72+
cumsum_weights_mc = similar(weights_mc)
73+
74+
jump_times = Vector{Float64}(undef, JUMP_TIMES_WHICH_INIT_SIZE)
75+
jump_which = Vector{Int}(undef, JUMP_TIMES_WHICH_INIT_SIZE)
76+
jump_times_which_idx = Ref(1)
77+
78+
random_n = Ref(rand(rng))
79+
80+
_affect! = LindbladJump(
81+
c_ops_data,
82+
c_ops_herm_data,
83+
rng,
84+
random_n,
85+
cache_mc,
86+
weights_mc,
87+
cumsum_weights_mc,
88+
jump_times,
89+
jump_which,
90+
jump_times_which_idx,
91+
)
4192

4293
if jump_callback isa DiscreteLindbladJumpCallback
4394
cb1 = DiscreteCallback(_mcsolve_discrete_condition, _affect!, save_positions = (false, false))
@@ -69,35 +120,38 @@ function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
69120
end
70121
end
71122

72-
function _lindblad_jump_affect!(integrator, c_ops, c_ops_herm)
73-
params = integrator.p
74-
cache_mc = params.mcsolve_params.cache_mc
75-
weights_mc = params.mcsolve_params.weights_mc
76-
cumsum_weights_mc = params.mcsolve_params.cumsum_weights_mc
77-
random_n = params.mcsolve_params.random_n
78-
jump_times = params.mcsolve_params.jump_times
79-
jump_which = params.mcsolve_params.jump_which
80-
jump_times_which_idx = params.mcsolve_params.jump_times_which_idx
81-
traj_rng = params.mcsolve_params.traj_rng
123+
function _lindblad_jump_affect!(
124+
integrator,
125+
c_ops,
126+
c_ops_herm,
127+
traj_rng,
128+
random_n,
129+
cache_mc,
130+
weights_mc,
131+
cumsum_weights_mc,
132+
jump_times,
133+
jump_which,
134+
jump_times_which_idx,
135+
)
82136
ψ = integrator.u
83137

84138
@inbounds for i in eachindex(weights_mc)
85139
weights_mc[i] = real(dot(ψ, c_ops_herm[i], ψ))
86140
end
87141
cumsum!(cumsum_weights_mc, weights_mc)
88-
r = rand(traj_rng) * sum(real, weights_mc)
89-
collapse_idx = getindex(1:length(weights_mc), findfirst(x -> real(x) > r, cumsum_weights_mc))
142+
r = rand(traj_rng) * sum(weights_mc)
143+
collapse_idx = getindex(1:length(weights_mc), findfirst(>(r), cumsum_weights_mc))
90144
mul!(cache_mc, c_ops[collapse_idx], ψ)
91145
normalize!(cache_mc)
92146
copyto!(integrator.u, cache_mc)
93147

94-
@inbounds random_n[1] = rand(traj_rng)
148+
random_n[] = rand(traj_rng)
95149

96-
@inbounds idx = round(Int, real(jump_times_which_idx[1]))
150+
idx = jump_times_which_idx[]
97151
@inbounds jump_times[idx] = integrator.t
98152
@inbounds jump_which[idx] = collapse_idx
99-
@inbounds jump_times_which_idx[1] += 1
100-
@inbounds if real(jump_times_which_idx[1]) > length(jump_times)
153+
jump_times_which_idx[] += 1
154+
if jump_times_which_idx[] > length(jump_times)
101155
resize!(jump_times, length(jump_times) + JUMP_TIMES_WHICH_INIT_SIZE)
102156
resize!(jump_which, length(jump_which) + JUMP_TIMES_WHICH_INIT_SIZE)
103157
end
@@ -106,89 +160,181 @@ function _lindblad_jump_affect!(integrator, c_ops, c_ops_herm)
106160
end
107161

108162
_mcsolve_continuous_condition(u, t, integrator) =
109-
@inbounds real(integrator.p.mcsolve_params.random_n[1]) - real(dot(u, u))
163+
@inbounds _mc_get_jump_callback(integrator).affect!.random_n[] - real(dot(u, u))
110164

111165
_mcsolve_discrete_condition(u, t, integrator) =
112-
@inbounds real(dot(u, u)) < real(integrator.p.mcsolve_params.random_n[1])
166+
@inbounds real(dot(u, u)) < _mc_get_jump_callback(integrator).affect!.random_n[]
167+
168+
##
169+
170+
#=
171+
_mc_get_save_callback
172+
173+
Return the Callback that is responsible for saving the expectation values of the system.
174+
=#
175+
function _mc_get_save_callback(sol::AbstractODESolution)
176+
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
177+
return _mc_get_save_callback(kwargs.callback) # There is always the Jump callback
178+
end
179+
_mc_get_save_callback(integrator::AbstractODEIntegrator) = _mc_get_save_callback(integrator.opts.callback)
180+
function _mc_get_save_callback(cb::CallbackSet)
181+
cbs_discrete = cb.discrete_callbacks
182+
183+
if length(cbs_discrete) > 0
184+
idx = _mcsolve_has_continuous_jump(cb) ? 1 : 2
185+
_cb = cb.discrete_callbacks[idx]
186+
return _mc_get_save_callback(_cb)
187+
else
188+
return nothing
189+
end
190+
end
191+
_mc_get_save_callback(cb::DiscreteCallback) =
192+
if cb.affect! isa SaveFuncMCSolve
193+
return cb
194+
else
195+
return nothing
196+
end
197+
_mc_get_save_callback(cb::ContinuousCallback) = nothing
198+
199+
##
200+
201+
function _mc_get_jump_callback(sol::AbstractODESolution)
202+
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
203+
return _mc_get_jump_callback(kwargs.callback) # There is always the Jump callback
204+
end
205+
_mc_get_jump_callback(integrator::AbstractODEIntegrator) = _mc_get_jump_callback(integrator.opts.callback)
206+
_mc_get_jump_callback(cb::CallbackSet) =
207+
if _mcsolve_has_continuous_jump(cb)
208+
return cb.continuous_callbacks[1]
209+
else
210+
return cb.discrete_callbacks[1]
211+
end
212+
_mc_get_jump_callback(cb::ContinuousCallback) = cb
213+
_mc_get_jump_callback(cb::DiscreteCallback) = cb
214+
215+
##
113216

114217
#=
115218
With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
116219
This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
117220
=#
118221
function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator)
119-
cb_set = integrator.opts.callback # This is supposed to be a CallbackSet
120-
(cb_set isa CallbackSet) || throw(ArgumentError("The callback must be a CallbackSet."))
121-
cb = isempty(cb_set.continuous_callbacks) ? cb_set.discrete_callback[1] : cb_set.continuous_callbacks[1]
122-
return cb.affect!.c_ops, cb.affect!.c_ops_herm
222+
cb = _mc_get_jump_callback(integrator)
223+
if cb isa Nothing
224+
return nothing
225+
else
226+
return cb.affect!.c_ops, cb.affect!.c_ops_herm
227+
end
123228
end
124229

125230
#=
126231
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
127232
This callback can only be a PresetTimeCallback (DiscreteCallback).
128233
=#
129234
function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator)
130-
cb_set = integrator.opts.callback # This is supposed to be a CallbackSet
131-
(cb_set isa CallbackSet) || throw(ArgumentError("The callback must be a CallbackSet."))
132-
cb = length(cb_set.continuous_callbacks) > 0 ? cb_set.discrete_callbacks[1] : cb_set.discrete_callbacks[2]
133-
return cb.affect!.e_ops
235+
cb = _mc_get_save_callback(integrator)
236+
if cb isa Nothing
237+
return nothing
238+
else
239+
return cb.affect!.e_ops
240+
end
134241
end
135242

136243
function _mcsolve_get_expvals(sol::AbstractODESolution)
137-
cb = NamedTuple(sol.prob.kwargs).callback
138-
if _mcsolve_has_discrete_callbacks(cb)
139-
return _mcsolve_get_expvals(cb)
140-
else
244+
cb = _mc_get_save_callback(sol)
245+
if cb isa Nothing
141246
return nothing
142-
end
143-
end
144-
function _mcsolve_get_expvals(cb::CallbackSet)
145-
idx = _mcsolve_has_continuous_jump(cb) ? 1 : 2
146-
_cb = cb.discrete_callbacks[idx]
147-
return _mcsolve_get_expvals(_cb)
148-
end
149-
_mcsolve_get_expvals(cb::DiscreteCallback) =
150-
if cb.affect! isa SaveFuncMCSolve
151-
return cb.affect!.expvals
152247
else
153-
nothing
248+
return cb.affect!.expvals
154249
end
155-
_mcsolve_get_expvals(cb::ContinuousCallback) = nothing
250+
end
156251

157252
#=
158-
_mcsolve_callbacks_new_iter_expvals(prob, tlist)
253+
_mcsolve_initialize_callbacks(prob, tlist)
159254
160255
Return the same callbacks of the `prob`, but with the `iter` variable reinitialized to 1 and the `expvals` variable reinitialized to a new matrix.
161256
=#
162-
function _mcsolve_callbacks_new_iter_expvals(prob, tlist)
257+
function _mcsolve_initialize_callbacks(prob, tlist, traj_rng)
163258
cb = prob.kwargs[:callback]
164-
return _mcsolve_callbacks_new_iter_expvals(cb, tlist)
259+
return _mcsolve_initialize_callbacks(cb, tlist, traj_rng)
165260
end
166-
function _mcsolve_callbacks_new_iter_expvals(cb::CallbackSet, tlist)
261+
function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng)
167262
cb_continuous = cb.continuous_callbacks
168263
cb_discrete = cb.discrete_callbacks
169264

170265
if _mcsolve_has_continuous_jump(cb)
171266
idx = 1
172-
e_ops = cb_discrete[idx].affect!.e_ops
173-
expvals = similar(cb_discrete[idx].affect!.expvals)
174-
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
175-
cb_save = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
176-
return CallbackSet(cb_continuous..., cb_save, cb_discrete[2:end]...)
267+
if cb_discrete[idx].affect! isa SaveFuncMCSolve
268+
e_ops = cb_discrete[idx].affect!.e_ops
269+
expvals = similar(cb_discrete[idx].affect!.expvals)
270+
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
271+
cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),)
272+
else
273+
cb_save = ()
274+
end
275+
276+
_jump_affect! = _similar_affect!(cb_continuous[1].affect!, traj_rng)
277+
cb_jump = _modify_field(cb_continuous[1], :affect!, _jump_affect!)
278+
279+
return CallbackSet((cb_jump, cb_continuous[2:end]...), (cb_save..., cb_discrete[2:end]...))
177280
else
178281
idx = 2
179-
e_ops = cb_discrete[idx].affect!.e_ops
180-
expvals = similar(cb_discrete[idx].affect!.expvals)
181-
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
182-
cb_save = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
183-
return CallbackSet(cb_continuous..., cb_discrete[1], cb_save, cb_discrete[3:end]...)
282+
if cb_discrete[idx].affect! isa SaveFuncMCSolve
283+
e_ops = cb_discrete[idx].affect!.e_ops
284+
expvals = similar(cb_discrete[idx].affect!.expvals)
285+
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
286+
cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),)
287+
else
288+
cb_save = ()
289+
end
290+
291+
_jump_affect! = _similar_affect!(cb_discrete[1].affect!, traj_rng)
292+
cb_jump = _modify_field(cb_discrete[1], :affect!, _jump_affect!)
293+
294+
return CallbackSet(cb_continuous, (cb_jump, cb_save..., cb_discrete[3:end]...))
184295
end
185296
end
186-
_mcsolve_callbacks_new_iter_expvals(cb::ContinuousCallback, tlist) = cb # It is only the continuous LindbladJump callback
187-
_mcsolve_callbacks_new_iter_expvals(cb::DiscreteCallback, tlist) = cb # It is only the discrete LindbladJump callback
297+
# _mcsolve_initialize_callbacks(cb::ContinuousCallback, tlist) = cb # It is only the continuous LindbladJump callback
298+
# _mcsolve_initialize_callbacks(cb::DiscreteCallback, tlist) = cb # It is only the discrete LindbladJump callback
299+
function _mcsolve_initialize_callbacks(cb::CBT, tlist, traj_rng) where {CBT<:Union{ContinuousCallback,DiscreteCallback}}
300+
_jump_affect! = _similar_affect!(cb.affect!, traj_rng)
301+
return _modify_field(cb, :affect!, _jump_affect!)
302+
end
188303

189-
_mcsolve_has_discrete_callbacks(cb::CallbackSet) = length(cb.discrete_callbacks) > 0
190-
_mcsolve_has_discrete_callbacks(cb::ContinuousCallback) = false
191-
_mcsolve_has_discrete_callbacks(cb::DiscreteCallback) = true
304+
#=
305+
_similar_affect!
306+
307+
Return a new LindbladJump with the same fields as the input LindbladJump but with new memory.
308+
=#
309+
function _similar_affect!(affect::LindbladJump, traj_rng)
310+
random_n = Ref(rand(traj_rng))
311+
cache_mc = similar(affect.cache_mc)
312+
weights_mc = similar(affect.weights_mc)
313+
cumsum_weights_mc = similar(affect.cumsum_weights_mc)
314+
jump_times = similar(affect.jump_times)
315+
jump_which = similar(affect.jump_which)
316+
jump_times_which_idx = Ref(1)
317+
318+
return LindbladJump(
319+
affect.c_ops,
320+
affect.c_ops_herm,
321+
traj_rng,
322+
random_n,
323+
cache_mc,
324+
weights_mc,
325+
cumsum_weights_mc,
326+
jump_times,
327+
jump_which,
328+
jump_times_which_idx,
329+
)
330+
end
331+
332+
function _modify_field(obj::T, field_name::Symbol, field_val) where {T}
333+
# Create a NamedTuple of fields, deepcopying only the selected ones
334+
fields = (name != field_name ? (getfield(obj, name)) : field_val for name in fieldnames(T))
335+
# Reconstruct the struct with the updated fields
336+
return Base.typename(T).wrapper(fields...)
337+
end
192338

193339
_mcsolve_has_continuous_jump(cb::CallbackSet) =
194340
(length(cb.continuous_callbacks) > 0) && (cb.continuous_callbacks[1].affect! isa LindbladJump)

0 commit comments

Comments
 (0)