@@ -4,13 +4,43 @@ This file contains helper functions for callbacks. The affect! function are defi
44
55# #
66
7+ abstract type AbstractSaveFunc end
8+
79# Multiple dispatch depending on the progress_bar and e_ops types
810function _generate_se_me_kwargs (e_ops, progress_bar, tlist, kwargs, method)
911 cb = _generate_save_callback (e_ops, tlist, progress_bar, method)
1012 return _merge_kwargs_with_callback (kwargs, cb)
1113end
1214_generate_se_me_kwargs (e_ops:: Nothing , progress_bar:: Val{false} , tlist, kwargs, method) = kwargs
1315
16+ function _generate_stochastic_kwargs (
17+ e_ops,
18+ sc_ops,
19+ progress_bar,
20+ tlist,
21+ store_measurement,
22+ kwargs,
23+ method:: Type{SF} ,
24+ ) where {SF<: AbstractSaveFunc }
25+ cb_save = _generate_stochastic_save_callback (e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
26+
27+ if SF === SaveFuncSSESolve
28+ cb_normalize = _ssesolve_generate_normalize_cb ()
29+ return _merge_kwargs_with_callback (kwargs, CallbackSet (cb_normalize, cb_save))
30+ end
31+
32+ return _merge_kwargs_with_callback (kwargs, cb_save)
33+ end
34+ _generate_stochastic_kwargs (
35+ e_ops:: Nothing ,
36+ sc_ops,
37+ progress_bar:: Val{false} ,
38+ tlist,
39+ store_measurement:: Val{false} ,
40+ kwargs,
41+ method:: Type{SF} ,
42+ ) where {SF<: AbstractSaveFunc } = kwargs
43+
1444function _merge_kwargs_with_callback (kwargs, cb)
1545 kwargs2 =
1646 haskey (kwargs, :callback ) ? merge (kwargs, (callback = CallbackSet (cb, kwargs. callback),)) :
@@ -30,77 +60,111 @@ function _generate_save_callback(e_ops, tlist, progress_bar, method)
3060 return PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
3161end
3262
33- _get_e_ops_data (e_ops, :: Type{SaveFuncSESolve} ) = get_data .(e_ops)
34- _get_e_ops_data (e_ops, :: Type{SaveFuncMESolve} ) = [_generate_mesolve_e_op (op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10
35- _get_e_ops_data (e_ops, :: Type{SaveFuncSSESolve} ) = get_data .(e_ops)
36-
37- _generate_mesolve_e_op (op) = mat2vec (adjoint (get_data (op)))
38-
39- #=
40- This function add the normalization callback to the kwargs. It is needed to stabilize the integration when using the ssesolve method.
41- =#
42- function _ssesolve_add_normalize_cb (kwargs)
43- _condition = (u, t, integrator) -> true
44- _affect! = (integrator) -> normalize! (integrator. u)
45- cb = DiscreteCallback (_condition, _affect!; save_positions = (false , false ))
46- # return merge(kwargs, (callback = CallbackSet(kwargs[:callback], cb),))
63+ function _generate_stochastic_save_callback (e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
64+ e_ops_data = e_ops isa Nothing ? nothing : _get_e_ops_data (e_ops, method)
65+ m_ops_data = _get_m_ops_data (sc_ops, method)
4766
48- cb_set = haskey (kwargs, :callback ) ? CallbackSet (kwargs[ :callback ], cb) : cb
67+ progr = getVal (progress_bar ) ? ProgressBar ( length (tlist), enable = getVal (progress_bar)) : nothing
4968
50- kwargs2 = merge (kwargs, (callback = cb_set,))
69+ expvals = e_ops isa Nothing ? nothing : Array {ComplexF64} (undef, length (e_ops), length (tlist))
70+ m_expvals = getVal (store_measurement) ? Array {Float64} (undef, length (sc_ops), length (tlist) - 1 ) : nothing
5171
52- return kwargs2
72+ _save_affect! = method (store_measurement, e_ops_data, m_ops_data, progr, Ref (1 ), expvals, m_expvals)
73+ return PresetTimeCallback (tlist, _save_affect!, save_positions = (false , false ))
5374end
5475
5576# #
5677
57- # When e_ops is Nothing. Common for both mesolve and sesolve
78+ # When e_ops is Nothing. Common for all solvers
5879function _save_func (integrator, progr)
5980 next! (progr)
6081 u_modified! (integrator, false )
6182 return nothing
6283end
6384
64- # When progr is Nothing. Common for both mesolve and sesolve
85+ # When progr is Nothing. Common for all solvers
6586function _save_func (integrator, progr:: Nothing )
6687 u_modified! (integrator, false )
6788 return nothing
6889end
6990
7091# #
7192
93+ #=
94+ To extract the measurement outcomes of a stochastic solver
95+ =#
96+ function _get_m_expvals (integrator:: AbstractODESolution , method:: Type{SF} ) where {SF<: AbstractSaveFunc }
97+ cb = _get_save_callback (integrator, method)
98+ if cb isa Nothing
99+ return nothing
100+ else
101+ return cb. affect!. m_expvals
102+ end
103+ end
104+
105+ #=
106+ With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
107+ This callback can only be a PresetTimeCallback (DiscreteCallback).
108+ =#
109+ function _get_e_ops (integrator:: AbstractODEIntegrator , method:: Type{SF} ) where {SF<: AbstractSaveFunc }
110+ cb = _get_save_callback (integrator, method)
111+ if cb isa Nothing
112+ return nothing
113+ else
114+ return cb. affect!. e_ops
115+ end
116+ end
117+
72118# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`.
73- function _se_me_sse_get_expvals (sol:: AbstractODESolution )
74- cb = _se_me_sse_get_save_callback (sol)
119+ function _get_expvals (sol:: AbstractODESolution , method :: Type{SF} ) where {SF <: AbstractSaveFunc }
120+ cb = _get_save_callback (sol, method )
75121 if cb isa Nothing
76122 return nothing
77123 else
78124 return cb. affect!. expvals
79125 end
80126end
81127
82- function _se_me_sse_get_save_callback (sol:: AbstractODESolution )
128+ #=
129+ _get_save_callback
130+
131+ Return the Callback that is responsible for saving the expectation values of the system.
132+ =#
133+ function _get_save_callback (sol:: AbstractODESolution , method:: Type{SF} ) where {SF<: AbstractSaveFunc }
83134 kwargs = NamedTuple (sol. prob. kwargs) # Convert to NamedTuple to support Zygote.jl
84135 if hasproperty (kwargs, :callback )
85- return _se_me_sse_get_save_callback (kwargs. callback)
136+ return _get_save_callback (kwargs. callback, method )
86137 else
87138 return nothing
88139 end
89140end
90- _se_me_sse_get_save_callback (integrator:: AbstractODEIntegrator ) = _se_me_sse_get_save_callback (integrator. opts. callback)
91- function _se_me_sse_get_save_callback (cb:: CallbackSet )
141+ _get_save_callback (integrator:: AbstractODEIntegrator , method:: Type{SF} ) where {SF<: AbstractSaveFunc } =
142+ _get_save_callback (integrator. opts. callback, method)
143+ function _get_save_callback (cb:: CallbackSet , method:: Type{SF} ) where {SF<: AbstractSaveFunc }
92144 cbs_discrete = cb. discrete_callbacks
93145 if length (cbs_discrete) > 0
94- _cb = cb. discrete_callbacks[1 ]
95- return _se_me_sse_get_save_callback (_cb)
146+ idx = _get_save_callback_idx (cb, method)
147+ _cb = cb. discrete_callbacks[idx]
148+ return _get_save_callback (_cb, method)
96149 else
97150 return nothing
98151 end
99152end
100- function _se_me_sse_get_save_callback (cb:: DiscreteCallback )
101- if typeof (cb. affect!) <: Union{SaveFuncSESolve,SaveFuncMESolve,SaveFuncSSESolve}
153+ function _get_save_callback (cb:: DiscreteCallback , :: Type{SF} ) where {SF <: AbstractSaveFunc }
154+ if typeof (cb. affect!) <: AbstractSaveFunc
102155 return cb
103156 end
104157 return nothing
105158end
106- _se_me_sse_get_save_callback (cb:: ContinuousCallback ) = nothing
159+ _get_save_callback (cb:: ContinuousCallback , :: Type{SF} ) where {SF<: AbstractSaveFunc } = nothing
160+
161+ _get_save_callback_idx (cb, method) = 1
162+
163+ # %% ------------ Noise Measurement Helpers ------------ %%
164+
165+ # TODO : Add some cache mechanism to avoid memory allocations
166+ function _homodyne_dWdt (integrator)
167+ @inbounds _dWdt = (integrator. W. u[end ] .- integrator. W. u[end - 1 ]) ./ (integrator. W. t[end ] - integrator. W. t[end - 1 ])
168+
169+ return _dWdt
170+ end
0 commit comments