Skip to content

Commit 19c1e0c

Browse files
Working smesolve
1 parent ad896ff commit 19c1e0c

File tree

5 files changed

+106
-16
lines changed

5 files changed

+106
-16
lines changed

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,12 @@ end
159159
_get_save_callback(cb::ContinuousCallback, ::Type{SF}) where {SF<:AbstractSaveFunc} = nothing
160160

161161
_get_save_callback_idx(cb, method) = 1
162+
163+
# %% ------------ Noise Measurement Helpers ------------ %%
164+
165+
# TODO: Add some cache mechanism to avoid memory allocations
166+
function _homodyne_dWdt(integrator)
167+
@inbounds _dWdt = (integrator.W.u[end] .- integrator.W.u[end-1]) ./ (integrator.W.t[end] - integrator.W.t[end-1])
168+
169+
return _dWdt
170+
end
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,55 @@
1+
#=
2+
Helper functions for the smesolve callbacks. Almost equal to the mesolve case, but with an additional possibility to store the measurement operators expectation values.
3+
=#
14

5+
struct SaveFuncSMESolve{
6+
SM,
7+
TE,
8+
TME,
9+
PT<:Union{Nothing,ProgressBar},
10+
IT,
11+
TEXPV<:Union{Nothing,AbstractMatrix},
12+
TMEXPV<:Union{Nothing,AbstractMatrix},
13+
} <: AbstractSaveFunc
14+
store_measurement::Val{SM}
15+
e_ops::TE
16+
m_ops::TME
17+
progr::PT
18+
iter::IT
19+
expvals::TEXPV
20+
m_expvals::TMEXPV
21+
end
22+
23+
(f::SaveFuncSMESolve)(integrator) =
24+
_save_func_smesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
25+
(f::SaveFuncSMESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
26+
27+
_get_e_ops_data(e_ops, ::Type{SaveFuncSMESolve}) = _get_e_ops_data(e_ops, SaveFuncMESolve)
28+
_get_m_ops_data(sc_ops, ::Type{SaveFuncSMESolve}) =
29+
map(op -> _generate_mesolve_e_op(op) + _generate_mesolve_e_op(op'), sc_ops)
30+
31+
##
32+
33+
# When e_ops is a list of operators
34+
function _save_func_smesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
35+
# This is equivalent to tr(op * ρ), when both are matrices.
36+
# The advantage of using this convention is that We don't need
37+
# to reshape u to make it a matrix, but we reshape the e_ops once.
38+
39+
ρ = integrator.u
40+
41+
_expect = op -> dot(op, ρ)
42+
43+
if !isnothing(e_ops)
44+
@. expvals[:, iter[]] = _expect(e_ops)
45+
end
46+
47+
if !isnothing(m_expvals) && iter[] > 1
48+
_dWdt = _homodyne_dWdt(integrator)
49+
@. m_expvals[:, iter[]-1] = real(_expect(m_ops)) + _dWdt
50+
end
51+
52+
iter[] += 1
53+
54+
return _save_func(integrator, progr)
55+
end

src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#=
2-
Helper functions for the ssesolve callbacks. Equal to the sesolve case, but with an additional normalization before saving the expectation values.
2+
Helper functions for the ssesolve callbacks. Almost equal to the sesolve case, but with an additional possibility to store the measurement operators expectation values. Also, this callback is not the first one, but the second one, due to the presence of the normalization callback.
33
=#
44

55
struct SaveFuncSSESolve{
@@ -51,13 +51,6 @@ function _save_func_ssesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_e
5151
return _save_func(integrator, progr)
5252
end
5353

54-
# TODO: Add some cache mechanism to avoid memory allocations
55-
function _homodyne_dWdt(integrator)
56-
@inbounds _dWdt = (integrator.W.u[end] .- integrator.W.u[end-1]) ./ (integrator.W.t[end] - integrator.W.t[end-1])
57-
58-
return _dWdt
59-
end
60-
6154
##
6255

6356
#=

src/time_evolution/smesolve.jl

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ _smesolve_ScalarOperator(op_vec) =
2020
params = NullParameters(),
2121
rng::AbstractRNG = default_rng(),
2222
progress_bar::Union{Val,Bool} = Val(true),
23+
store_measurement::Union{Val, Bool} = Val(false),
2324
kwargs...,
2425
)
2526
@@ -54,6 +55,7 @@ Above, ``\hat{C}_i`` represent the collapse operators related to pure dissipatio
5455
- `params`: `NullParameters` of parameters to pass to the solver.
5556
- `rng`: Random number generator for reproducibility.
5657
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
58+
- `store_measurement`: Whether to store the measurement expectation values. Default is `Val(false)`.
5759
- `kwargs`: The keyword arguments for the ODEProblem.
5860
5961
# Notes
@@ -77,6 +79,7 @@ function smesolveProblem(
7779
params = NullParameters(),
7880
rng::AbstractRNG = default_rng(),
7981
progress_bar::Union{Val,Bool} = Val(true),
82+
store_measurement::Union{Val,Bool} = Val(false),
8083
kwargs...,
8184
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject}}
8285
haskey(kwargs, :save_idxs) &&
@@ -111,11 +114,24 @@ function smesolveProblem(
111114
D = DiffusionOperator(D_l)
112115

113116
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_SDE_SOLVER_OPTIONS; kwargs...)
114-
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncMESolve)
117+
kwargs3 = _generate_stochastic_kwargs(
118+
e_ops,
119+
sc_ops,
120+
makeVal(progress_bar),
121+
tlist,
122+
makeVal(store_measurement),
123+
kwargs2,
124+
SaveFuncSMESolve,
125+
)
115126

116127
tspan = (tlist[1], tlist[end])
117-
noise =
118-
RealWienerProcess!(tlist[1], zeros(length(sc_ops)), zeros(length(sc_ops)), save_everystep = false, rng = rng)
128+
noise = RealWienerProcess!(
129+
tlist[1],
130+
zeros(length(sc_ops)),
131+
zeros(length(sc_ops)),
132+
save_everystep = getVal(store_measurement),
133+
rng = rng,
134+
)
119135
noise_rate_prototype = similar(ρ0, length(ρ0), length(sc_ops))
120136
prob = SDEProblem{true}(
121137
K,
@@ -146,6 +162,7 @@ end
146162
prob_func::Union{Function, Nothing} = nothing,
147163
output_func::Union{Tuple,Nothing} = nothing,
148164
progress_bar::Union{Val,Bool} = Val(true),
165+
store_measurement::Union{Val,Bool} = Val(false),
149166
kwargs...,
150167
)
151168
@@ -184,6 +201,7 @@ Above, ``\hat{C}_i`` represent the collapse operators related to pure dissipatio
184201
- `prob_func`: Function to use for generating the SDEProblem.
185202
- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
186203
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
204+
- `store_measurement`: Whether to store the measurement expectation values. Default is `Val(false)`.
187205
- `kwargs`: The keyword arguments for the ODEProblem.
188206
189207
# Notes
@@ -211,11 +229,19 @@ function smesolveEnsembleProblem(
211229
prob_func::Union{Function,Nothing} = nothing,
212230
output_func::Union{Tuple,Nothing} = nothing,
213231
progress_bar::Union{Val,Bool} = Val(true),
232+
store_measurement::Union{Val,Bool} = Val(false),
214233
kwargs...,
215234
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject}}
216235
_prob_func =
217236
isnothing(prob_func) ?
218-
_ensemble_dispatch_prob_func(rng, ntraj, tlist, _stochastic_prob_func; n_sc_ops = length(sc_ops)) : prob_func
237+
_ensemble_dispatch_prob_func(
238+
rng,
239+
ntraj,
240+
tlist,
241+
_stochastic_prob_func;
242+
n_sc_ops = length(sc_ops),
243+
store_measurement = makeVal(store_measurement),
244+
) : prob_func
219245
_output_func =
220246
output_func isa Nothing ?
221247
_ensemble_dispatch_output_func(ensemble_method, progress_bar, ntraj, _stochastic_output_func) : output_func
@@ -230,6 +256,7 @@ function smesolveEnsembleProblem(
230256
params = params,
231257
rng = rng,
232258
progress_bar = Val(false),
259+
store_measurement = makeVal(store_measurement),
233260
kwargs...,
234261
)
235262

@@ -259,6 +286,7 @@ end
259286
prob_func::Union{Function, Nothing} = nothing,
260287
output_func::Union{Tuple,Nothing} = nothing,
261288
progress_bar::Union{Val,Bool} = Val(true),
289+
store_measurement::Union{Val,Bool} = Val(false),
262290
kwargs...,
263291
)
264292
@@ -298,6 +326,7 @@ Above, ``\hat{C}_i`` represent the collapse operators related to pure dissipatio
298326
- `prob_func`: Function to use for generating the SDEProblem.
299327
- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
300328
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
329+
- `store_measurement`: Whether to store the measurement expectation values. Default is `Val(false)`.
301330
- `kwargs`: The keyword arguments for the ODEProblem.
302331
303332
# Notes
@@ -326,6 +355,7 @@ function smesolve(
326355
prob_func::Union{Function,Nothing} = nothing,
327356
output_func::Union{Tuple,Nothing} = nothing,
328357
progress_bar::Union{Val,Bool} = Val(true),
358+
store_measurement::Union{Val,Bool} = Val(false),
329359
kwargs...,
330360
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject}}
331361
ensemble_prob = smesolveEnsembleProblem(
@@ -342,6 +372,7 @@ function smesolve(
342372
prob_func = prob_func,
343373
output_func = output_func,
344374
progress_bar = progress_bar,
375+
store_measurement = makeVal(store_measurement),
345376
kwargs...,
346377
)
347378

@@ -358,13 +389,18 @@ function smesolve(
358389

359390
_sol_1 = sol[:, 1]
360391
_expvals_sol_1 = _get_expvals(_sol_1, SaveFuncMESolve)
392+
_m_expvals_sol_1 = _get_m_expvals(_sol_1, SaveFuncSMESolve)
361393

362394
dims = ens_prob.dimensions
363395
_expvals_all =
364396
_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_expvals(sol[:, i], SaveFuncMESolve), eachindex(sol))
365397
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
366398
states = map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims)), eachindex(sol))
367399

400+
_m_expvals =
401+
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSMESolve), eachindex(sol))
402+
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2)
403+
368404
expvals =
369405
_get_expvals(_sol_1, SaveFuncMESolve) isa Nothing ? nothing :
370406
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
@@ -376,7 +412,7 @@ function smesolve(
376412
expvals,
377413
expvals, # This is average_expect
378414
expvals_all,
379-
nothing, # Measurement expectation values
415+
m_expvals, # Measurement expectation values
380416
sol.converged,
381417
_sol_1.alg,
382418
_sol_1.prob.kwargs[:abstol],

src/time_evolution/time_evolution.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,13 +347,11 @@ function _stochastic_prob_func(prob, i, repeat, rng, seeds, tlist; kwargs...)
347347
traj_rng = typeof(rng)()
348348
seed!(traj_rng, seed)
349349

350-
store_measurement = haskey(kwargs, :store_measurement) ? getVal(kwargs[:store_measurement]) : false
351-
352350
noise = RealWienerProcess!(
353351
prob.prob.tspan[1],
354352
zeros(kwargs[:n_sc_ops]),
355353
zeros(kwargs[:n_sc_ops]),
356-
save_everystep = store_measurement,
354+
save_everystep = getVal(kwargs[:store_measurement]),
357355
rng = traj_rng,
358356
)
359357

0 commit comments

Comments
 (0)