1010
1111(f:: SaveFuncMCSolve )(integrator) = _save_func_mcsolve (integrator, f. e_ops, f. iter, f. expvals)
1212
13- struct LindbladJump{T1,T2}
13+ struct LindbladJump{
14+ T1,
15+ T2,
16+ RNGType<: AbstractRNG ,
17+ RandT,
18+ CT<: AbstractVector ,
19+ WT<: AbstractVector ,
20+ JTT<: AbstractVector ,
21+ JWT<: AbstractVector ,
22+ JTWIT,
23+ }
1424 c_ops:: T1
1525 c_ops_herm:: T2
26+ traj_rng:: RNGType
27+ random_n:: RandT
28+ cache_mc:: CT
29+ weights_mc:: WT
30+ cumsum_weights_mc:: WT
31+ jump_times:: JTT
32+ jump_which:: JWT
33+ jump_times_which_idx:: JTWIT
1634end
1735
18- (f:: LindbladJump )(integrator) = _lindblad_jump_affect! (integrator, f. c_ops, f. c_ops_herm)
36+ (f:: LindbladJump )(integrator) = _lindblad_jump_affect! (
37+ integrator,
38+ f. c_ops,
39+ f. c_ops_herm,
40+ f. traj_rng,
41+ f. random_n,
42+ f. cache_mc,
43+ f. weights_mc,
44+ f. cumsum_weights_mc,
45+ f. jump_times,
46+ f. jump_which,
47+ f. jump_times_which_idx,
48+ )
1949
2050# #
2151
2252function _save_func_mcsolve (integrator, e_ops, iter, expvals)
23- cache_mc = integrator. p . mcsolve_params . cache_mc
53+ cache_mc = _mc_get_jump_callback ( integrator) . affect! . cache_mc
2454
2555 copyto! (cache_mc, integrator. u)
2656 normalize! (cache_mc)
@@ -33,11 +63,32 @@ function _save_func_mcsolve(integrator, e_ops, iter, expvals)
3363 return nothing
3464end
3565
36- function _generate_mcsolve_kwargs (e_ops, tlist, c_ops, jump_callback, kwargs)
66+ function _generate_mcsolve_kwargs (ψ0, T, e_ops, tlist, c_ops, jump_callback, rng , kwargs)
3767 c_ops_data = get_data .(c_ops)
3868 c_ops_herm_data = map (op -> op' * op, c_ops_data)
3969
40- _affect! = LindbladJump (c_ops_data, c_ops_herm_data)
70+ cache_mc = similar (ψ0. data, T)
71+ weights_mc = Vector {Float64} (undef, length (c_ops))
72+ cumsum_weights_mc = similar (weights_mc)
73+
74+ jump_times = Vector {Float64} (undef, JUMP_TIMES_WHICH_INIT_SIZE)
75+ jump_which = Vector {Int} (undef, JUMP_TIMES_WHICH_INIT_SIZE)
76+ jump_times_which_idx = Ref (1 )
77+
78+ random_n = Ref (rand (rng))
79+
80+ _affect! = LindbladJump (
81+ c_ops_data,
82+ c_ops_herm_data,
83+ rng,
84+ random_n,
85+ cache_mc,
86+ weights_mc,
87+ cumsum_weights_mc,
88+ jump_times,
89+ jump_which,
90+ jump_times_which_idx,
91+ )
4192
4293 if jump_callback isa DiscreteLindbladJumpCallback
4394 cb1 = DiscreteCallback (_mcsolve_discrete_condition, _affect!, save_positions = (false , false ))
@@ -69,35 +120,38 @@ function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
69120 end
70121end
71122
72- function _lindblad_jump_affect! (integrator, c_ops, c_ops_herm)
73- params = integrator. p
74- cache_mc = params. mcsolve_params. cache_mc
75- weights_mc = params. mcsolve_params. weights_mc
76- cumsum_weights_mc = params. mcsolve_params. cumsum_weights_mc
77- random_n = params. mcsolve_params. random_n
78- jump_times = params. mcsolve_params. jump_times
79- jump_which = params. mcsolve_params. jump_which
80- jump_times_which_idx = params. mcsolve_params. jump_times_which_idx
81- traj_rng = params. mcsolve_params. traj_rng
123+ function _lindblad_jump_affect! (
124+ integrator,
125+ c_ops,
126+ c_ops_herm,
127+ traj_rng,
128+ random_n,
129+ cache_mc,
130+ weights_mc,
131+ cumsum_weights_mc,
132+ jump_times,
133+ jump_which,
134+ jump_times_which_idx,
135+ )
82136 ψ = integrator. u
83137
84138 @inbounds for i in eachindex (weights_mc)
85139 weights_mc[i] = real (dot (ψ, c_ops_herm[i], ψ))
86140 end
87141 cumsum! (cumsum_weights_mc, weights_mc)
88- r = rand (traj_rng) * sum (real, weights_mc)
89- collapse_idx = getindex (1 : length (weights_mc), findfirst (x -> real (x) > r , cumsum_weights_mc))
142+ r = rand (traj_rng) * sum (weights_mc)
143+ collapse_idx = getindex (1 : length (weights_mc), findfirst (> (r) , cumsum_weights_mc))
90144 mul! (cache_mc, c_ops[collapse_idx], ψ)
91145 normalize! (cache_mc)
92146 copyto! (integrator. u, cache_mc)
93147
94- @inbounds random_n[1 ] = rand (traj_rng)
148+ random_n[] = rand (traj_rng)
95149
96- @inbounds idx = round (Int, real ( jump_times_which_idx[1 ]))
150+ idx = jump_times_which_idx[]
97151 @inbounds jump_times[idx] = integrator. t
98152 @inbounds jump_which[idx] = collapse_idx
99- @inbounds jump_times_which_idx[1 ] += 1
100- @inbounds if real ( jump_times_which_idx[1 ]) > length (jump_times)
153+ jump_times_which_idx[] += 1
154+ if jump_times_which_idx[] > length (jump_times)
101155 resize! (jump_times, length (jump_times) + JUMP_TIMES_WHICH_INIT_SIZE)
102156 resize! (jump_which, length (jump_which) + JUMP_TIMES_WHICH_INIT_SIZE)
103157 end
@@ -106,89 +160,181 @@ function _lindblad_jump_affect!(integrator, c_ops, c_ops_herm)
106160end
107161
108162_mcsolve_continuous_condition (u, t, integrator) =
109- @inbounds real (integrator. p . mcsolve_params . random_n[1 ]) - real (dot (u, u))
163+ @inbounds _mc_get_jump_callback (integrator) . affect! . random_n[] - real (dot (u, u))
110164
111165_mcsolve_discrete_condition (u, t, integrator) =
112- @inbounds real (dot (u, u)) < real (integrator. p. mcsolve_params. random_n[1 ])
166+ @inbounds real (dot (u, u)) < _mc_get_jump_callback (integrator). affect!. random_n[]
167+
168+ # #
169+
170+ #=
171+ _mc_get_save_callback
172+
173+ Return the Callback that is responsible for saving the expectation values of the system.
174+ =#
175+ function _mc_get_save_callback (sol:: AbstractODESolution )
176+ kwargs = NamedTuple (sol. prob. kwargs) # Convert to NamedTuple to support Zygote.jl
177+ return _mc_get_save_callback (kwargs. callback) # There is always the Jump callback
178+ end
179+ _mc_get_save_callback (integrator:: AbstractODEIntegrator ) = _mc_get_save_callback (integrator. opts. callback)
180+ function _mc_get_save_callback (cb:: CallbackSet )
181+ cbs_discrete = cb. discrete_callbacks
182+
183+ if length (cbs_discrete) > 0
184+ idx = _mcsolve_has_continuous_jump (cb) ? 1 : 2
185+ _cb = cb. discrete_callbacks[idx]
186+ return _mc_get_save_callback (_cb)
187+ else
188+ return nothing
189+ end
190+ end
191+ _mc_get_save_callback (cb:: DiscreteCallback ) =
192+ if cb. affect! isa SaveFuncMCSolve
193+ return cb
194+ else
195+ return nothing
196+ end
197+ _mc_get_save_callback (cb:: ContinuousCallback ) = nothing
198+
199+ # #
200+
201+ function _mc_get_jump_callback (sol:: AbstractODESolution )
202+ kwargs = NamedTuple (sol. prob. kwargs) # Convert to NamedTuple to support Zygote.jl
203+ return _mc_get_jump_callback (kwargs. callback) # There is always the Jump callback
204+ end
205+ _mc_get_jump_callback (integrator:: AbstractODEIntegrator ) = _mc_get_jump_callback (integrator. opts. callback)
206+ _mc_get_jump_callback (cb:: CallbackSet ) =
207+ if _mcsolve_has_continuous_jump (cb)
208+ return cb. continuous_callbacks[1 ]
209+ else
210+ return cb. discrete_callbacks[1 ]
211+ end
212+ _mc_get_jump_callback (cb:: ContinuousCallback ) = cb
213+ _mc_get_jump_callback (cb:: DiscreteCallback ) = cb
214+
215+ # #
113216
114217#=
115218With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
116219This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
117220=#
118221function _mcsolve_get_c_ops (integrator:: AbstractODEIntegrator )
119- cb_set = integrator. opts. callback # This is supposed to be a CallbackSet
120- (cb_set isa CallbackSet) || throw (ArgumentError (" The callback must be a CallbackSet." ))
121- cb = isempty (cb_set. continuous_callbacks) ? cb_set. discrete_callback[1 ] : cb_set. continuous_callbacks[1 ]
122- return cb. affect!. c_ops, cb. affect!. c_ops_herm
222+ cb = _mc_get_jump_callback (integrator)
223+ if cb isa Nothing
224+ return nothing
225+ else
226+ return cb. affect!. c_ops, cb. affect!. c_ops_herm
227+ end
123228end
124229
125230#=
126231With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
127232This callback can only be a PresetTimeCallback (DiscreteCallback).
128233=#
129234function _mcsolve_get_e_ops (integrator:: AbstractODEIntegrator )
130- cb_set = integrator. opts. callback # This is supposed to be a CallbackSet
131- (cb_set isa CallbackSet) || throw (ArgumentError (" The callback must be a CallbackSet." ))
132- cb = length (cb_set. continuous_callbacks) > 0 ? cb_set. discrete_callbacks[1 ] : cb_set. discrete_callbacks[2 ]
133- return cb. affect!. e_ops
235+ cb = _mc_get_save_callback (integrator)
236+ if cb isa Nothing
237+ return nothing
238+ else
239+ return cb. affect!. e_ops
240+ end
134241end
135242
136243function _mcsolve_get_expvals (sol:: AbstractODESolution )
137- cb = NamedTuple (sol. prob. kwargs). callback
138- if _mcsolve_has_discrete_callbacks (cb)
139- return _mcsolve_get_expvals (cb)
140- else
244+ cb = _mc_get_save_callback (sol)
245+ if cb isa Nothing
141246 return nothing
142- end
143- end
144- function _mcsolve_get_expvals (cb:: CallbackSet )
145- idx = _mcsolve_has_continuous_jump (cb) ? 1 : 2
146- _cb = cb. discrete_callbacks[idx]
147- return _mcsolve_get_expvals (_cb)
148- end
149- _mcsolve_get_expvals (cb:: DiscreteCallback ) =
150- if cb. affect! isa SaveFuncMCSolve
151- return cb. affect!. expvals
152247 else
153- nothing
248+ return cb . affect! . expvals
154249 end
155- _mcsolve_get_expvals (cb :: ContinuousCallback ) = nothing
250+ end
156251
157252#=
158- _mcsolve_callbacks_new_iter_expvals (prob, tlist)
253+ _mcsolve_initialize_callbacks (prob, tlist)
159254
160255Return the same callbacks of the `prob`, but with the `iter` variable reinitialized to 1 and the `expvals` variable reinitialized to a new matrix.
161256=#
162- function _mcsolve_callbacks_new_iter_expvals (prob, tlist)
257+ function _mcsolve_initialize_callbacks (prob, tlist, traj_rng )
163258 cb = prob. kwargs[:callback ]
164- return _mcsolve_callbacks_new_iter_expvals (cb, tlist)
259+ return _mcsolve_initialize_callbacks (cb, tlist, traj_rng )
165260end
166- function _mcsolve_callbacks_new_iter_expvals (cb:: CallbackSet , tlist)
261+ function _mcsolve_initialize_callbacks (cb:: CallbackSet , tlist, traj_rng )
167262 cb_continuous = cb. continuous_callbacks
168263 cb_discrete = cb. discrete_callbacks
169264
170265 if _mcsolve_has_continuous_jump (cb)
171266 idx = 1
172- e_ops = cb_discrete[idx]. affect!. e_ops
173- expvals = similar (cb_discrete[idx]. affect!. expvals)
174- _save_affect! = SaveFuncMCSolve (e_ops, Ref (1 ), expvals)
175- cb_save = PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
176- return CallbackSet (cb_continuous... , cb_save, cb_discrete[2 : end ]. .. )
267+ if cb_discrete[idx]. affect! isa SaveFuncMCSolve
268+ e_ops = cb_discrete[idx]. affect!. e_ops
269+ expvals = similar (cb_discrete[idx]. affect!. expvals)
270+ _save_affect! = SaveFuncMCSolve (e_ops, Ref (1 ), expvals)
271+ cb_save = (PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false )),)
272+ else
273+ cb_save = ()
274+ end
275+
276+ _jump_affect! = _similar_affect! (cb_continuous[1 ]. affect!, traj_rng)
277+ cb_jump = _modify_field (cb_continuous[1 ], :affect! , _jump_affect!)
278+
279+ return CallbackSet ((cb_jump, cb_continuous[2 : end ]. .. ), (cb_save... , cb_discrete[2 : end ]. .. ))
177280 else
178281 idx = 2
179- e_ops = cb_discrete[idx]. affect!. e_ops
180- expvals = similar (cb_discrete[idx]. affect!. expvals)
181- _save_affect! = SaveFuncMCSolve (e_ops, Ref (1 ), expvals)
182- cb_save = PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
183- return CallbackSet (cb_continuous... , cb_discrete[1 ], cb_save, cb_discrete[3 : end ]. .. )
282+ if cb_discrete[idx]. affect! isa SaveFuncMCSolve
283+ e_ops = cb_discrete[idx]. affect!. e_ops
284+ expvals = similar (cb_discrete[idx]. affect!. expvals)
285+ _save_affect! = SaveFuncMCSolve (e_ops, Ref (1 ), expvals)
286+ cb_save = (PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false )),)
287+ else
288+ cb_save = ()
289+ end
290+
291+ _jump_affect! = _similar_affect! (cb_discrete[1 ]. affect!, traj_rng)
292+ cb_jump = _modify_field (cb_discrete[1 ], :affect! , _jump_affect!)
293+
294+ return CallbackSet (cb_continuous, (cb_jump, cb_save... , cb_discrete[3 : end ]. .. ))
184295 end
185296end
186- _mcsolve_callbacks_new_iter_expvals (cb:: ContinuousCallback , tlist) = cb # It is only the continuous LindbladJump callback
187- _mcsolve_callbacks_new_iter_expvals (cb:: DiscreteCallback , tlist) = cb # It is only the discrete LindbladJump callback
297+ # _mcsolve_initialize_callbacks(cb::ContinuousCallback, tlist) = cb # It is only the continuous LindbladJump callback
298+ # _mcsolve_initialize_callbacks(cb::DiscreteCallback, tlist) = cb # It is only the discrete LindbladJump callback
299+ function _mcsolve_initialize_callbacks (cb:: CBT , tlist, traj_rng) where {CBT<: Union{ContinuousCallback,DiscreteCallback} }
300+ _jump_affect! = _similar_affect! (cb. affect!, traj_rng)
301+ return _modify_field (cb, :affect! , _jump_affect!)
302+ end
188303
189- _mcsolve_has_discrete_callbacks (cb:: CallbackSet ) = length (cb. discrete_callbacks) > 0
190- _mcsolve_has_discrete_callbacks (cb:: ContinuousCallback ) = false
191- _mcsolve_has_discrete_callbacks (cb:: DiscreteCallback ) = true
304+ #=
305+ _similar_affect!
306+
307+ Return a new LindbladJump with the same fields as the input LindbladJump but with new memory.
308+ =#
309+ function _similar_affect! (affect:: LindbladJump , traj_rng)
310+ random_n = Ref (rand (traj_rng))
311+ cache_mc = similar (affect. cache_mc)
312+ weights_mc = similar (affect. weights_mc)
313+ cumsum_weights_mc = similar (affect. cumsum_weights_mc)
314+ jump_times = similar (affect. jump_times)
315+ jump_which = similar (affect. jump_which)
316+ jump_times_which_idx = Ref (1 )
317+
318+ return LindbladJump (
319+ affect. c_ops,
320+ affect. c_ops_herm,
321+ traj_rng,
322+ random_n,
323+ cache_mc,
324+ weights_mc,
325+ cumsum_weights_mc,
326+ jump_times,
327+ jump_which,
328+ jump_times_which_idx,
329+ )
330+ end
331+
332+ function _modify_field (obj:: T , field_name:: Symbol , field_val) where {T}
333+ # Create a NamedTuple of fields, deepcopying only the selected ones
334+ fields = (name != field_name ? (getfield (obj, name)) : field_val for name in fieldnames (T))
335+ # Reconstruct the struct with the updated fields
336+ return Base. typename (T). wrapper (fields... )
337+ end
192338
193339_mcsolve_has_continuous_jump (cb:: CallbackSet ) =
194340 (length (cb. continuous_callbacks) > 0 ) && (cb. continuous_callbacks[1 ]. affect! isa LindbladJump)
0 commit comments