@@ -4,13 +4,14 @@ This file contains helper functions for callbacks. The affect! function are defi
44
55# ######### SESOLVE ##########
66
7- struct SaveFuncSESolve{TE,PT<: Union{Nothing,ProgressBar} ,IT}
7+ struct SaveFuncSESolve{TE,PT<: Union{Nothing,ProgressBar} ,IT,TEXPV <: Union{Nothing,AbstractMatrix} }
88 e_ops:: TE
99 progr:: PT
1010 iter:: IT
11+ expvals:: TEXPV
1112end
1213
13- (f:: SaveFuncSESolve )(integrator) = _save_func_sesolve (integrator, f. e_ops, f. progr, f. iter)
14+ (f:: SaveFuncSESolve )(integrator) = _save_func_sesolve (integrator, f. e_ops, f. progr, f. iter, f . expvals )
1415(f:: SaveFuncSESolve{Nothing} )(integrator) = _save_func_sesolve (integrator, f. progr)
1516
1617# #
@@ -29,8 +30,7 @@ function _save_func_sesolve(integrator, progr::Nothing)
2930end
3031
3132# When e_ops is a list of operators
32- function _save_func_sesolve (integrator, e_ops, progr, iter)
33- expvals = integrator. p. expvals
33+ function _save_func_sesolve (integrator, e_ops, progr, iter, expvals)
3434 ψ = integrator. u
3535 _expect = op -> dot (ψ, op, ψ)
3636 @. expvals[:, iter[]] = _expect (e_ops)
@@ -44,18 +44,40 @@ function _generate_sesolve_callback(e_ops, tlist, progress_bar)
4444
4545 progr = getVal (progress_bar) ? ProgressBar (length (tlist), enable = getVal (progress_bar)) : nothing
4646
47- _save_affect! = SaveFuncSESolve (e_ops_data, progr, Ref (1 ))
48- return PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
47+ expvals = e_ops isa Nothing ? nothing : Array {ComplexF64} (undef, length (e_ops), length (tlist))
48+
49+ _save_affect! = SaveFuncSESolve (e_ops_data, progr, Ref (1 ), expvals)
50+ return _PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
4951end
5052
53+ function _sesolve_get_expvals (sol:: AbstractODESolution )
54+ kwargs = NamedTuple (sol. prob. kwargs) # Convert to NamedTuple to support Zygote.jl
55+ if hasproperty (kwargs, :callback )
56+ return _sesolve_get_expvals (kwargs. callback)
57+ else
58+ return nothing
59+ end
60+ end
61+ function _sesolve_get_expvals (cb:: CallbackSet )
62+ _cb = cb. discrete_callbacks[1 ]
63+ return _sesolve_get_expvals (_cb)
64+ end
65+ _sesolve_get_expvals (cb:: DiscreteCallback ) = if cb. affect! isa SaveFuncSESolve
66+ return cb. affect!. expvals
67+ else
68+ return nothing
69+ end
70+ _sesolve_get_expvals (cb:: ContinuousCallback ) = nothing
71+
5172# ######### MCSOLVE ##########
5273
53- struct SaveFuncMCSolve{TE,IT}
74+ struct SaveFuncMCSolve{TE,IT,TEXPV }
5475 e_ops:: TE
5576 iter:: IT
77+ expvals:: TEXPV
5678end
5779
58- (f:: SaveFuncMCSolve )(integrator) = _save_func_mcsolve (integrator, f. e_ops, f. iter)
80+ (f:: SaveFuncMCSolve )(integrator) = _save_func_mcsolve (integrator, f. e_ops, f. iter, f . expvals )
5981
6082struct LindbladJump{T1,T2}
6183 c_ops:: T1
6688
6789# #
6890
69- function _save_func_mcsolve (integrator, e_ops, iter)
70- expvals = integrator. p. expvals
91+ function _save_func_mcsolve (integrator, e_ops, iter, expvals)
7192 cache_mc = integrator. p. mcsolve_params. cache_mc
7293
7394 copyto! (cache_mc, integrator. u)
@@ -100,12 +121,15 @@ function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
100121 end
101122
102123 if e_ops isa Nothing
124+ # We are implicitly saying that we don't have a `ProgressBar`
103125 kwargs2 =
104126 haskey (kwargs, :callback ) ? merge (kwargs, (callback = CallbackSet (cb1, kwargs. callback),)) :
105127 merge (kwargs, (callback = cb1,))
106128 return kwargs2
107129 else
108- _save_affect! = SaveFuncMCSolve (get_data .(e_ops), Ref (1 ))
130+ expvals = Array {ComplexF64} (undef, length (e_ops), length (tlist))
131+
132+ _save_affect! = SaveFuncMCSolve (get_data .(e_ops), Ref (1 ), expvals)
109133 cb2 = _PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
110134 kwargs2 =
111135 haskey (kwargs, :callback ) ? merge (kwargs, (callback = CallbackSet (cb1, cb2, kwargs. callback),)) :
@@ -178,35 +202,67 @@ function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator)
178202 return cb. affect!. e_ops
179203end
180204
205+ function _mcsolve_get_expvals (sol:: AbstractODESolution )
206+ cb = NamedTuple (sol. prob. kwargs). callback
207+ if _mcsolve_has_discrete_callbacks (cb)
208+ return _mcsolve_get_expvals (cb)
209+ else
210+ return nothing
211+ end
212+ end
213+ function _mcsolve_get_expvals (cb:: CallbackSet )
214+ idx = _mcsolve_has_continuous_jump (cb) ? 1 : 2
215+ _cb = cb. discrete_callbacks[idx]
216+ return _mcsolve_get_expvals (_cb)
217+ end
218+ _mcsolve_get_expvals (cb:: DiscreteCallback ) =
219+ if cb. affect! isa SaveFuncMCSolve
220+ return cb. affect!. expvals
221+ else
222+ nothing
223+ end
224+ _mcsolve_get_expvals (cb:: ContinuousCallback ) = nothing
225+
181226#=
182- _mcsolve_callbacks_new_iter (prob, tlist)
227+ _mcsolve_callbacks_new_iter_expvals (prob, tlist)
183228
184- Return the same callbacks of the `prob`, but with the `iter` variable reinitialized to 1.
229+ Return the same callbacks of the `prob`, but with the `iter` variable reinitialized to 1 and the `expvals` variable reinitialized to a new matrix .
185230=#
186- function _mcsolve_callbacks_new_iter (prob, tlist)
231+ function _mcsolve_callbacks_new_iter_expvals (prob, tlist)
187232 cb = prob. kwargs[:callback ]
188- return _mcsolve_callbacks_new_iter (cb, tlist)
233+ return _mcsolve_callbacks_new_iter_expvals (cb, tlist)
189234end
190- function _mcsolve_callbacks_new_iter (cb:: CallbackSet , tlist)
235+ function _mcsolve_callbacks_new_iter_expvals (cb:: CallbackSet , tlist)
191236 cb_continuous = cb. continuous_callbacks
192237 cb_discrete = cb. discrete_callbacks
193238
194- if length (cb_continuous) > 0
239+ if _mcsolve_has_continuous_jump (cb)
195240 idx = 1
196241 e_ops = cb_discrete[idx]. affect!. e_ops
197- _save_affect! = SaveFuncMCSolve (e_ops, Ref (1 ))
242+ expvals = similar (cb_discrete[idx]. affect!. expvals)
243+ _save_affect! = SaveFuncMCSolve (e_ops, Ref (1 ), expvals)
198244 cb_save = _PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
199245 return CallbackSet (cb_continuous... , cb_save, cb_discrete[2 : end ]. .. )
200246 else
201247 idx = 2
202248 e_ops = cb_discrete[idx]. affect!. e_ops
203- _save_affect! = SaveFuncMCSolve (e_ops, Ref (1 ))
249+ expvals = similar (cb_discrete[idx]. affect!. expvals)
250+ _save_affect! = SaveFuncMCSolve (e_ops, Ref (1 ), expvals)
204251 cb_save = _PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
205252 return CallbackSet (cb_continuous... , cb_discrete[1 ], cb_save, cb_discrete[3 : end ]. .. )
206253 end
207254end
208- _mcsolve_callbacks_new_iter (cb:: ContinuousCallback , tlist) = cb
209- _mcsolve_callbacks_new_iter (cb:: DiscreteCallback , tlist) = cb
255+ _mcsolve_callbacks_new_iter_expvals (cb:: ContinuousCallback , tlist) = cb # It is only the continuous LindbladJump callback
256+ _mcsolve_callbacks_new_iter_expvals (cb:: DiscreteCallback , tlist) = cb # It is only the discrete LindbladJump callback
257+
258+ _mcsolve_has_discrete_callbacks (cb:: CallbackSet ) = length (cb. discrete_callbacks) > 0
259+ _mcsolve_has_discrete_callbacks (cb:: ContinuousCallback ) = false
260+ _mcsolve_has_discrete_callbacks (cb:: DiscreteCallback ) = true
261+
262+ _mcsolve_has_continuous_jump (cb:: CallbackSet ) =
263+ (length (cb. continuous_callbacks) > 0 ) && (cb. continuous_callbacks[1 ]. affect! isa LindbladJump)
264+ _mcsolve_has_continuous_jump (cb:: ContinuousCallback ) = true
265+ _mcsolve_has_continuous_jump (cb:: DiscreteCallback ) = false
210266
211267# # Temporary function to avoid errors. Waiting for the PR In DiffEqCallbacks.jl to be merged.
212268
0 commit comments