Skip to content

Commit 49e9ffc

Browse files
Introduce measurement on ssesolve and smesolve (#404)
1 parent 81724ad commit 49e9ffc

17 files changed

+395
-165
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
- Align some attributes of `mcsolve`, `ssesolve` and `smesolve` results with `QuTiP`. ([#402])
1717
- Improve ensemble generation of `ssesolve` and change parameters handling on stochastic processes. ([#403])
1818
- Set default trajectories to 500 and rename the keyword argument `ensemble_method` to `ensemblealg`. ([#405])
19+
- Introduce measurement on `ssesolve` and `smesolve`. ([#404])
1920

2021
## [v0.26.0]
2122
Release date: 2025-02-09
@@ -133,4 +134,5 @@ Release date: 2024-11-13
133134
[#398]: https://github.com/qutip/QuantumToolbox.jl/issues/398
134135
[#402]: https://github.com/qutip/QuantumToolbox.jl/issues/402
135136
[#403]: https://github.com/qutip/QuantumToolbox.jl/issues/403
137+
[#404]: https://github.com/qutip/QuantumToolbox.jl/issues/404
136138
[#405]: https://github.com/qutip/QuantumToolbox.jl/issues/405

docs/src/users_guide/time_evolution/stochastic.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,16 @@ sse_sol = ssesolve(
105105
sc_ops,
106106
e_ops = [x],
107107
ntraj = ntraj,
108+
store_measurement = Val(true),
108109
)
109110
111+
measurement_avg = sum(sse_sol.measurement, dims=2) / size(sse_sol.measurement, 2)
112+
measurement_avg = dropdims(measurement_avg, dims=2)
113+
110114
# plot by CairoMakie.jl
111115
fig = Figure(size = (500, 350))
112116
ax = Axis(fig[1, 1], xlabel = "Time")
113-
#lines!(ax, tlist, real(sse_sol.xxxxxx), label = L"J_x", color = :red, linestyle = :solid) TODO: add this in the future
117+
lines!(ax, tlist[1:end-1], real(measurement_avg[1,:]), label = L"J_x", color = :red, linestyle = :solid)
114118
lines!(ax, tlist, real(sse_sol.expect[1,:]), label = L"\langle x \rangle", color = :black, linestyle = :solid)
115119
116120
axislegend(ax, position = :rt)
@@ -134,12 +138,16 @@ sme_sol = smesolve(
134138
sc_ops,
135139
e_ops = [x],
136140
ntraj = ntraj,
141+
store_measurement = Val(true),
137142
)
138143
144+
measurement_avg = sum(sme_sol.measurement, dims=2) / size(sme_sol.measurement, 2)
145+
measurement_avg = dropdims(measurement_avg, dims=2)
146+
139147
# plot by CairoMakie.jl
140148
fig = Figure(size = (500, 350))
141149
ax = Axis(fig[1, 1], xlabel = "Time")
142-
#lines!(ax, tlist, real(sme_sol.xxxxxx), label = L"J_x", color = :red, linestyle = :solid) TODO: add this in the future
150+
lines!(ax, tlist[1:end-1], real(measurement_avg[1,:]), label = L"J_x", color = :red, linestyle = :solid)
143151
lines!(ax, tlist, real(sme_sol.expect[1,:]), label = L"\langle x \rangle", color = :black, linestyle = :solid)
144152
145153
axislegend(ax, position = :rt)

src/QuantumToolbox.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,12 @@ include("qobj/block_diagonal_form.jl")
9797

9898
# time evolution
9999
include("time_evolution/time_evolution.jl")
100+
include("time_evolution/callback_helpers/callback_helpers.jl")
100101
include("time_evolution/callback_helpers/sesolve_callback_helpers.jl")
101102
include("time_evolution/callback_helpers/mesolve_callback_helpers.jl")
102103
include("time_evolution/callback_helpers/mcsolve_callback_helpers.jl")
103104
include("time_evolution/callback_helpers/ssesolve_callback_helpers.jl")
104-
include("time_evolution/callback_helpers/callback_helpers.jl")
105+
include("time_evolution/callback_helpers/smesolve_callback_helpers.jl")
105106
include("time_evolution/mesolve.jl")
106107
include("time_evolution/lr_mesolve.jl")
107108
include("time_evolution/sesolve.jl")

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 94 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,43 @@ This file contains helper functions for callbacks. The affect! function are defi
44

55
##
66

7+
abstract type AbstractSaveFunc end
8+
79
# Multiple dispatch depending on the progress_bar and e_ops types
810
function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method)
911
cb = _generate_save_callback(e_ops, tlist, progress_bar, method)
1012
return _merge_kwargs_with_callback(kwargs, cb)
1113
end
1214
_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs
1315

16+
function _generate_stochastic_kwargs(
17+
e_ops,
18+
sc_ops,
19+
progress_bar,
20+
tlist,
21+
store_measurement,
22+
kwargs,
23+
method::Type{SF},
24+
) where {SF<:AbstractSaveFunc}
25+
cb_save = _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
26+
27+
if SF === SaveFuncSSESolve
28+
cb_normalize = _ssesolve_generate_normalize_cb()
29+
return _merge_kwargs_with_callback(kwargs, CallbackSet(cb_normalize, cb_save))
30+
end
31+
32+
return _merge_kwargs_with_callback(kwargs, cb_save)
33+
end
34+
_generate_stochastic_kwargs(
35+
e_ops::Nothing,
36+
sc_ops,
37+
progress_bar::Val{false},
38+
tlist,
39+
store_measurement::Val{false},
40+
kwargs,
41+
method::Type{SF},
42+
) where {SF<:AbstractSaveFunc} = kwargs
43+
1444
function _merge_kwargs_with_callback(kwargs, cb)
1545
kwargs2 =
1646
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) :
@@ -30,77 +60,111 @@ function _generate_save_callback(e_ops, tlist, progress_bar, method)
3060
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
3161
end
3262

33-
_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)
34-
_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10
35-
_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)
36-
37-
_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op)))
38-
39-
#=
40-
This function add the normalization callback to the kwargs. It is needed to stabilize the integration when using the ssesolve method.
41-
=#
42-
function _ssesolve_add_normalize_cb(kwargs)
43-
_condition = (u, t, integrator) -> true
44-
_affect! = (integrator) -> normalize!(integrator.u)
45-
cb = DiscreteCallback(_condition, _affect!; save_positions = (false, false))
46-
# return merge(kwargs, (callback = CallbackSet(kwargs[:callback], cb),))
63+
function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
64+
e_ops_data = e_ops isa Nothing ? nothing : _get_e_ops_data(e_ops, method)
65+
m_ops_data = _get_m_ops_data(sc_ops, method)
4766

48-
cb_set = haskey(kwargs, :callback) ? CallbackSet(kwargs[:callback], cb) : cb
67+
progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing
4968

50-
kwargs2 = merge(kwargs, (callback = cb_set,))
69+
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
70+
m_expvals = getVal(store_measurement) ? Array{Float64}(undef, length(sc_ops), length(tlist) - 1) : nothing
5171

52-
return kwargs2
72+
_save_affect! = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
73+
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
5374
end
5475

5576
##
5677

57-
# When e_ops is Nothing. Common for both mesolve and sesolve
78+
# When e_ops is Nothing. Common for all solvers
5879
function _save_func(integrator, progr)
5980
next!(progr)
6081
u_modified!(integrator, false)
6182
return nothing
6283
end
6384

64-
# When progr is Nothing. Common for both mesolve and sesolve
85+
# When progr is Nothing. Common for all solvers
6586
function _save_func(integrator, progr::Nothing)
6687
u_modified!(integrator, false)
6788
return nothing
6889
end
6990

7091
##
7192

93+
#=
94+
To extract the measurement outcomes of a stochastic solver
95+
=#
96+
function _get_m_expvals(integrator::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
97+
cb = _get_save_callback(integrator, method)
98+
if cb isa Nothing
99+
return nothing
100+
else
101+
return cb.affect!.m_expvals
102+
end
103+
end
104+
105+
#=
106+
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
107+
This callback can only be a PresetTimeCallback (DiscreteCallback).
108+
=#
109+
function _get_e_ops(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc}
110+
cb = _get_save_callback(integrator, method)
111+
if cb isa Nothing
112+
return nothing
113+
else
114+
return cb.affect!.e_ops
115+
end
116+
end
117+
72118
# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`.
73-
function _se_me_sse_get_expvals(sol::AbstractODESolution)
74-
cb = _se_me_sse_get_save_callback(sol)
119+
function _get_expvals(sol::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
120+
cb = _get_save_callback(sol, method)
75121
if cb isa Nothing
76122
return nothing
77123
else
78124
return cb.affect!.expvals
79125
end
80126
end
81127

82-
function _se_me_sse_get_save_callback(sol::AbstractODESolution)
128+
#=
129+
_get_save_callback
130+
131+
Return the Callback that is responsible for saving the expectation values of the system.
132+
=#
133+
function _get_save_callback(sol::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
83134
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
84135
if hasproperty(kwargs, :callback)
85-
return _se_me_sse_get_save_callback(kwargs.callback)
136+
return _get_save_callback(kwargs.callback, method)
86137
else
87138
return nothing
88139
end
89140
end
90-
_se_me_sse_get_save_callback(integrator::AbstractODEIntegrator) = _se_me_sse_get_save_callback(integrator.opts.callback)
91-
function _se_me_sse_get_save_callback(cb::CallbackSet)
141+
_get_save_callback(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc} =
142+
_get_save_callback(integrator.opts.callback, method)
143+
function _get_save_callback(cb::CallbackSet, method::Type{SF}) where {SF<:AbstractSaveFunc}
92144
cbs_discrete = cb.discrete_callbacks
93145
if length(cbs_discrete) > 0
94-
_cb = cb.discrete_callbacks[1]
95-
return _se_me_sse_get_save_callback(_cb)
146+
idx = _get_save_callback_idx(cb, method)
147+
_cb = cb.discrete_callbacks[idx]
148+
return _get_save_callback(_cb, method)
96149
else
97150
return nothing
98151
end
99152
end
100-
function _se_me_sse_get_save_callback(cb::DiscreteCallback)
101-
if typeof(cb.affect!) <: Union{SaveFuncSESolve,SaveFuncMESolve,SaveFuncSSESolve}
153+
function _get_save_callback(cb::DiscreteCallback, ::Type{SF}) where {SF<:AbstractSaveFunc}
154+
if typeof(cb.affect!) <: AbstractSaveFunc
102155
return cb
103156
end
104157
return nothing
105158
end
106-
_se_me_sse_get_save_callback(cb::ContinuousCallback) = nothing
159+
_get_save_callback(cb::ContinuousCallback, ::Type{SF}) where {SF<:AbstractSaveFunc} = nothing
160+
161+
_get_save_callback_idx(cb, method) = 1
162+
163+
# %% ------------ Noise Measurement Helpers ------------ %%
164+
165+
# TODO: Add some cache mechanism to avoid memory allocations
166+
function _homodyne_dWdt(integrator)
167+
@inbounds _dWdt = (integrator.W.u[end] .- integrator.W.u[end-1]) ./ (integrator.W.t[end] - integrator.W.t[end-1])
168+
169+
return _dWdt
170+
end

src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 6 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
Helper functions for the mcsolve callbacks.
33
=#
44

5-
struct SaveFuncMCSolve{TE,IT,TEXPV}
5+
struct SaveFuncMCSolve{TE,IT,TEXPV} <: AbstractSaveFunc
66
e_ops::TE
77
iter::IT
88
expvals::TEXPV
99
end
1010

1111
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals)
1212

13+
_get_save_callback_idx(cb, ::Type{SaveFuncMCSolve}) = _mcsolve_has_continuous_jump(cb) ? 1 : 2
14+
15+
##
1316
struct LindbladJump{
1417
T1,
1518
T2,
@@ -167,37 +170,6 @@ _mcsolve_discrete_condition(u, t, integrator) =
167170

168171
##
169172

170-
#=
171-
_mc_get_save_callback
172-
173-
Return the Callback that is responsible for saving the expectation values of the system.
174-
=#
175-
function _mc_get_save_callback(sol::AbstractODESolution)
176-
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
177-
return _mc_get_save_callback(kwargs.callback) # There is always the Jump callback
178-
end
179-
_mc_get_save_callback(integrator::AbstractODEIntegrator) = _mc_get_save_callback(integrator.opts.callback)
180-
function _mc_get_save_callback(cb::CallbackSet)
181-
cbs_discrete = cb.discrete_callbacks
182-
183-
if length(cbs_discrete) > 0
184-
idx = _mcsolve_has_continuous_jump(cb) ? 1 : 2
185-
_cb = cb.discrete_callbacks[idx]
186-
return _mc_get_save_callback(_cb)
187-
else
188-
return nothing
189-
end
190-
end
191-
_mc_get_save_callback(cb::DiscreteCallback) =
192-
if cb.affect! isa SaveFuncMCSolve
193-
return cb
194-
else
195-
return nothing
196-
end
197-
_mc_get_save_callback(cb::ContinuousCallback) = nothing
198-
199-
##
200-
201173
function _mc_get_jump_callback(sol::AbstractODESolution)
202174
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
203175
return _mc_get_jump_callback(kwargs.callback) # There is always the Jump callback
@@ -215,8 +187,8 @@ _mc_get_jump_callback(cb::DiscreteCallback) = cb
215187
##
216188

217189
#=
218-
With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
219-
This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
190+
With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
191+
This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
220192
=#
221193
function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator)
222194
cb = _mc_get_jump_callback(integrator)
@@ -227,28 +199,6 @@ function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator)
227199
end
228200
end
229201

230-
#=
231-
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
232-
This callback can only be a PresetTimeCallback (DiscreteCallback).
233-
=#
234-
function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator)
235-
cb = _mc_get_save_callback(integrator)
236-
if cb isa Nothing
237-
return nothing
238-
else
239-
return cb.affect!.e_ops
240-
end
241-
end
242-
243-
function _mcsolve_get_expvals(sol::AbstractODESolution)
244-
cb = _mc_get_save_callback(sol)
245-
if cb isa Nothing
246-
return nothing
247-
else
248-
return cb.affect!.expvals
249-
end
250-
end
251-
252202
#=
253203
_mcsolve_initialize_callbacks(prob, tlist)
254204

src/time_evolution/callback_helpers/mesolve_callback_helpers.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Helper functions for the mesolve callbacks.
33
=#
44

5-
struct SaveFuncMESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
5+
struct SaveFuncMESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}} <: AbstractSaveFunc
66
e_ops::TE
77
progr::PT
88
iter::IT
@@ -12,6 +12,8 @@ end
1212
(f::SaveFuncMESolve)(integrator) = _save_func_mesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
1313
(f::SaveFuncMESolve{Nothing})(integrator) = _save_func(integrator, f.progr)
1414

15+
_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10
16+
1517
##
1618

1719
# When e_ops is a list of operators
@@ -29,11 +31,13 @@ function _save_func_mesolve(integrator, e_ops, progr, iter, expvals)
2931
end
3032

3133
function _mesolve_callbacks_new_e_ops!(integrator::AbstractODEIntegrator, e_ops)
32-
cb = _se_me_sse_get_save_callback(integrator)
34+
cb = _get_save_callback(integrator, SaveFuncMESolve)
3335
if cb isa Nothing
3436
return nothing
3537
else
3638
cb.affect!.e_ops .= e_ops # Only works if e_ops is a Vector of operators
3739
return nothing
3840
end
3941
end
42+
43+
_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op)))

src/time_evolution/callback_helpers/sesolve_callback_helpers.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Helper functions for the sesolve callbacks.
33
=#
44

5-
struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
5+
struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}} <: AbstractSaveFunc
66
e_ops::TE
77
progr::PT
88
iter::IT
@@ -12,6 +12,8 @@ end
1212
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
1313
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve
1414

15+
_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)
16+
1517
##
1618

1719
# When e_ops is a list of operators

0 commit comments

Comments
 (0)