Skip to content

Commit ad896ff

Browse files
Working case of ssesolve measurement
1 parent 3f6d46e commit ad896ff

File tree

7 files changed

+133
-19
lines changed

7 files changed

+133
-19
lines changed

src/QuantumToolbox.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ include("time_evolution/callback_helpers/sesolve_callback_helpers.jl")
101101
include("time_evolution/callback_helpers/mesolve_callback_helpers.jl")
102102
include("time_evolution/callback_helpers/mcsolve_callback_helpers.jl")
103103
include("time_evolution/callback_helpers/ssesolve_callback_helpers.jl")
104+
include("time_evolution/callback_helpers/smesolve_callback_helpers.jl")
104105
include("time_evolution/mesolve.jl")
105106
include("time_evolution/lr_mesolve.jl")
106107
include("time_evolution/sesolve.jl")

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@ function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method)
1313
end
1414
_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs
1515

16+
function _generate_stochastic_kwargs(
17+
e_ops,
18+
sc_ops,
19+
progress_bar,
20+
tlist,
21+
store_measurement,
22+
kwargs,
23+
method::Type{SF},
24+
) where {SF<:AbstractSaveFunc}
25+
cb_save = _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
26+
27+
if SF === SaveFuncSSESolve
28+
cb_normalize = _ssesolve_generate_normalize_cb()
29+
return _merge_kwargs_with_callback(kwargs, CallbackSet(cb_normalize, cb_save))
30+
end
31+
32+
return _merge_kwargs_with_callback(kwargs, cb_save)
33+
end
34+
_generate_stochastic_kwargs(
35+
e_ops::Nothing,
36+
sc_ops,
37+
progress_bar::Val{false},
38+
tlist,
39+
store_measurement::Val{false},
40+
kwargs,
41+
method::Type{SF},
42+
) where {SF<:AbstractSaveFunc} = kwargs
43+
1644
function _merge_kwargs_with_callback(kwargs, cb)
1745
kwargs2 =
1846
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) :
@@ -32,6 +60,19 @@ function _generate_save_callback(e_ops, tlist, progress_bar, method)
3260
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
3361
end
3462

63+
function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
64+
e_ops_data = e_ops isa Nothing ? nothing : _get_e_ops_data(e_ops, method)
65+
m_ops_data = _get_m_ops_data(sc_ops, method)
66+
67+
progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing
68+
69+
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
70+
m_expvals = getVal(store_measurement) ? Array{Float64}(undef, length(sc_ops), length(tlist) - 1) : nothing
71+
72+
_save_affect! = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
73+
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
74+
end
75+
3576
##
3677

3778
# When e_ops is Nothing. Common for all solvers
@@ -49,6 +90,18 @@ end
4990

5091
##
5192

93+
#=
94+
To extract the measurement outcomes of a stochastic solver
95+
=#
96+
function _get_m_expvals(integrator::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
97+
cb = _get_save_callback(integrator, method)
98+
if cb isa Nothing
99+
return nothing
100+
else
101+
return cb.affect!.m_expvals
102+
end
103+
end
104+
52105
#=
53106
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
54107
This callback can only be a PresetTimeCallback (DiscreteCallback).
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,71 @@
22
Helper functions for the ssesolve callbacks. Equal to the sesolve case, but with an additional normalization before saving the expectation values.
33
=#
44

5-
struct SaveFuncSSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}} <: AbstractSaveFunc
5+
struct SaveFuncSSESolve{
6+
SM,
7+
TE,
8+
TME,
9+
PT<:Union{Nothing,ProgressBar},
10+
IT,
11+
TEXPV<:Union{Nothing,AbstractMatrix},
12+
TMEXPV<:Union{Nothing,AbstractMatrix},
13+
} <: AbstractSaveFunc
14+
store_measurement::Val{SM}
615
e_ops::TE
16+
m_ops::TME
717
progr::PT
818
iter::IT
919
expvals::TEXPV
20+
m_expvals::TMEXPV
1021
end
1122

12-
(f::SaveFuncSSESolve)(integrator) = _save_func_ssesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
13-
(f::SaveFuncSSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
23+
(f::SaveFuncSSESolve)(integrator) =
24+
_save_func_ssesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
25+
(f::SaveFuncSSESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
1426

1527
_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)
28+
_get_m_ops_data(sc_ops, ::Type{SaveFuncSSESolve}) = map(op -> Hermitian(get_data(op) + get_data(op)'), sc_ops)
1629

17-
_get_save_callback_idx(cb, ::Type{SaveFuncSSESolve}) = 1 # The first one is the normalization callback
30+
_get_save_callback_idx(cb, ::Type{SaveFuncSSESolve}) = 2 # The first one is the normalization callback
1831

1932
##
2033

2134
# When e_ops is a list of operators
22-
function _save_func_ssesolve(integrator, e_ops, progr, iter, expvals)
23-
ψ = normalize!(integrator.u)
35+
function _save_func_ssesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
36+
ψ = integrator.u
37+
2438
_expect = op -> dot(ψ, op, ψ)
25-
@. expvals[:, iter[]] = _expect(e_ops)
39+
40+
if !isnothing(e_ops)
41+
@. expvals[:, iter[]] = _expect(e_ops)
42+
end
43+
44+
if !isnothing(m_expvals) && iter[] > 1
45+
_dWdt = _homodyne_dWdt(integrator)
46+
@. m_expvals[:, iter[]-1] = real(_expect(m_ops)) + _dWdt
47+
end
48+
2649
iter[] += 1
2750

2851
return _save_func(integrator, progr)
2952
end
3053

54+
# TODO: Add some cache mechanism to avoid memory allocations
55+
function _homodyne_dWdt(integrator)
56+
@inbounds _dWdt = (integrator.W.u[end] .- integrator.W.u[end-1]) ./ (integrator.W.t[end] - integrator.W.t[end-1])
57+
58+
return _dWdt
59+
end
60+
3161
##
3262

3363
#=
34-
This function adds the normalization callback to the kwargs. It is needed to stabilize the integration when using the ssesolve method.
64+
This function generates the normalization callback. It is needed to stabilize the integration when using the ssesolve method.
3565
=#
36-
function _ssesolve_add_normalize_cb(kwargs)
66+
function _ssesolve_generate_normalize_cb()
3767
_condition = (u, t, integrator) -> true
3868
_affect! = (integrator) -> normalize!(integrator.u)
3969
cb = DiscreteCallback(_condition, _affect!; save_positions = (false, false))
4070

41-
cb_set = haskey(kwargs, :callback) ? CallbackSet(kwargs[:callback], cb) : cb
42-
43-
kwargs2 = merge(kwargs, (callback = cb_set,))
44-
45-
return kwargs2
71+
return cb
4672
end

src/time_evolution/smesolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ function smesolve(
376376
expvals,
377377
expvals, # This is average_expect
378378
expvals_all,
379+
nothing, # Measurement expectation values
379380
sol.converged,
380381
_sol_1.alg,
381382
_sol_1.prob.kwargs[:abstol],

src/time_evolution/ssesolve.jl

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,15 @@ function ssesolveProblem(
114114
D = DiffusionOperator(D_l)
115115

116116
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_SDE_SOLVER_OPTIONS; kwargs...)
117-
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSSESolve)
118-
kwargs4 = _ssesolve_add_normalize_cb(kwargs3)
117+
kwargs3 = _generate_stochastic_kwargs(
118+
e_ops,
119+
sc_ops,
120+
makeVal(progress_bar),
121+
tlist,
122+
makeVal(store_measurement),
123+
kwargs2,
124+
SaveFuncSSESolve,
125+
)
119126

120127
tspan = (tlist[1], tlist[end])
121128
noise = RealWienerProcess!(
@@ -134,7 +141,7 @@ function ssesolveProblem(
134141
params;
135142
noise_rate_prototype = noise_rate_prototype,
136143
noise = noise,
137-
kwargs4...,
144+
kwargs3...,
138145
)
139146

140147
return TimeEvolutionProblem(prob, tlist, dims)
@@ -154,6 +161,7 @@ end
154161
prob_func::Union{Function, Nothing} = nothing,
155162
output_func::Union{Tuple,Nothing} = nothing,
156163
progress_bar::Union{Val,Bool} = Val(true),
164+
store_measurement::Union{Val,Bool} = Val(false),
157165
kwargs...,
158166
)
159167
@@ -193,6 +201,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is t
193201
- `prob_func`: Function to use for generating the SDEProblem.
194202
- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
195203
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
204+
- `store_measurement`: Whether to store the measurement results. Default is `Val(false)`.
196205
- `kwargs`: The keyword arguments for the ODEProblem.
197206
198207
# Notes
@@ -219,11 +228,19 @@ function ssesolveEnsembleProblem(
219228
prob_func::Union{Function,Nothing} = nothing,
220229
output_func::Union{Tuple,Nothing} = nothing,
221230
progress_bar::Union{Val,Bool} = Val(true),
231+
store_measurement::Union{Val,Bool} = Val(false),
222232
kwargs...,
223233
)
224234
_prob_func =
225235
isnothing(prob_func) ?
226-
_ensemble_dispatch_prob_func(rng, ntraj, tlist, _stochastic_prob_func; n_sc_ops = length(sc_ops)) : prob_func
236+
_ensemble_dispatch_prob_func(
237+
rng,
238+
ntraj,
239+
tlist,
240+
_stochastic_prob_func;
241+
n_sc_ops = length(sc_ops),
242+
store_measurement = makeVal(store_measurement),
243+
) : prob_func
227244
_output_func =
228245
output_func isa Nothing ?
229246
_ensemble_dispatch_output_func(ensemble_method, progress_bar, ntraj, _stochastic_output_func) : output_func
@@ -237,6 +254,7 @@ function ssesolveEnsembleProblem(
237254
params = params,
238255
rng = rng,
239256
progress_bar = Val(false),
257+
store_measurement = makeVal(store_measurement),
240258
kwargs...,
241259
)
242260

@@ -265,6 +283,7 @@ end
265283
prob_func::Union{Function, Nothing} = nothing,
266284
output_func::Union{Tuple,Nothing} = nothing,
267285
progress_bar::Union{Val,Bool} = Val(true),
286+
store_measurement::Union{Val,Bool} = Val(false),
268287
kwargs...,
269288
)
270289
@@ -307,6 +326,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is th
307326
- `prob_func`: Function to use for generating the SDEProblem.
308327
- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
309328
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
329+
- `store_measurement`: Whether to store the measurement results. Default is `Val(false)`.
310330
- `kwargs`: The keyword arguments for the ODEProblem.
311331
312332
# Notes
@@ -335,6 +355,7 @@ function ssesolve(
335355
prob_func::Union{Function,Nothing} = nothing,
336356
output_func::Union{Tuple,Nothing} = nothing,
337357
progress_bar::Union{Val,Bool} = Val(true),
358+
store_measurement::Union{Val,Bool} = Val(false),
338359
kwargs...,
339360
)
340361
ens_prob = ssesolveEnsembleProblem(
@@ -350,6 +371,7 @@ function ssesolve(
350371
prob_func = prob_func,
351372
output_func = output_func,
352373
progress_bar = progress_bar,
374+
store_measurement = makeVal(store_measurement),
353375
kwargs...,
354376
)
355377

@@ -366,6 +388,7 @@ function ssesolve(
366388

367389
_sol_1 = sol[:, 1]
368390
_expvals_sol_1 = _get_expvals(_sol_1, SaveFuncSSESolve)
391+
_m_expvals_sol_1 = _get_m_expvals(_sol_1, SaveFuncSSESolve)
369392

370393
normalize_states = Val(false)
371394
dims = ens_prob.dimensions
@@ -374,6 +397,10 @@ function ssesolve(
374397
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
375398
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
376399

400+
_m_expvals =
401+
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol))
402+
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2)
403+
377404
expvals =
378405
_get_expvals(_sol_1, SaveFuncSSESolve) isa Nothing ? nothing :
379406
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
@@ -385,6 +412,7 @@ function ssesolve(
385412
expvals,
386413
expvals, # This is average_expect
387414
expvals_all,
415+
m_expvals, # Measurement expectation values
388416
sol.converged,
389417
_sol_1.alg,
390418
_sol_1.prob.kwargs[:abstol],

src/time_evolution/time_evolution.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ struct TimeEvolutionStochasticSol{
173173
TS<:AbstractVector,
174174
TE<:Union{AbstractMatrix,Nothing},
175175
TEA<:Union{AbstractArray,Nothing},
176+
TEM<:Union{AbstractArray,Nothing},
176177
AlgT<:StochasticDiffEqAlgorithm,
177178
AT<:Real,
178179
RT<:Real,
@@ -183,6 +184,7 @@ struct TimeEvolutionStochasticSol{
183184
expect::TE
184185
average_expect::TE # Currently just a synonym for `expect`
185186
runs_expect::TEA
187+
measurement::TEM
186188
converged::Bool
187189
alg::AlgT
188190
abstol::AT
@@ -345,11 +347,13 @@ function _stochastic_prob_func(prob, i, repeat, rng, seeds, tlist; kwargs...)
345347
traj_rng = typeof(rng)()
346348
seed!(traj_rng, seed)
347349

350+
store_measurement = haskey(kwargs, :store_measurement) ? getVal(kwargs[:store_measurement]) : false
351+
348352
noise = RealWienerProcess!(
349353
prob.prob.tspan[1],
350354
zeros(kwargs[:n_sc_ops]),
351355
zeros(kwargs[:n_sc_ops]),
352-
save_everystep = false,
356+
save_everystep = store_measurement,
353357
rng = traj_rng,
354358
)
355359

0 commit comments

Comments
 (0)