Skip to content

Commit 95396ed

Browse files
Align mcsolve _get_expvals with the other solvers
1 parent 9e77d60 commit 95396ed

File tree

6 files changed

+48
-80
lines changed

6 files changed

+48
-80
lines changed

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ end
5151

5252
##
5353

54+
#=
55+
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
56+
This callback can only be a PresetTimeCallback (DiscreteCallback).
57+
=#
58+
function _get_e_ops(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc}
59+
cb = _get_save_callback(integrator, method)
60+
if cb isa Nothing
61+
return nothing
62+
else
63+
return cb.affect!.e_ops
64+
end
65+
end
66+
5467
# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`.
5568
function _get_expvals(sol::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
5669
cb = _get_save_callback(sol, method)
@@ -61,6 +74,11 @@ function _get_expvals(sol::AbstractODESolution, method::Type{SF}) where {SF<:Abs
6174
end
6275
end
6376

77+
#=
78+
_get_save_callback
79+
80+
Return the Callback that is responsible for saving the expectation values of the system.
81+
=#
6482
function _get_save_callback(sol::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
6583
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
6684
if hasproperty(kwargs, :callback)
@@ -74,19 +92,19 @@ _get_save_callback(integrator::AbstractODEIntegrator, method::Type{SF}) where {S
7492
function _get_save_callback(cb::CallbackSet, method::Type{SF}) where {SF<:AbstractSaveFunc}
7593
cbs_discrete = cb.discrete_callbacks
7694
if length(cbs_discrete) > 0
77-
idx = _get_save_callback_idx(method)
95+
idx = _get_save_callback_idx(cb, method)
7896
_cb = cb.discrete_callbacks[idx]
7997
return _get_save_callback(_cb, method)
8098
else
8199
return nothing
82100
end
83101
end
84-
function _get_save_callback(cb::DiscreteCallback, method::Type{SF}) where {SF<:AbstractSaveFunc}
85-
if typeof(cb.affect!) <: SF
102+
function _get_save_callback(cb::DiscreteCallback, ::Type{SF}) where {SF<:AbstractSaveFunc}
103+
if typeof(cb.affect!) <: AbstractSaveFunc
86104
return cb
87105
end
88106
return nothing
89107
end
90-
_get_save_callback(cb::ContinuousCallback, method::Type{SF}) where {SF<:AbstractSaveFunc} = nothing
108+
_get_save_callback(cb::ContinuousCallback, ::Type{SF}) where {SF<:AbstractSaveFunc} = nothing
91109

92-
_get_save_callback_idx(method) = 1
110+
_get_save_callback_idx(cb, method) = 1

src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ end
1010

1111
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals)
1212

13+
_get_save_callback_idx(cb, ::Type{SaveFuncMCSolve}) = _mcsolve_has_continuous_jump(cb) ? 1 : 2
14+
1315
##
1416
struct LindbladJump{
1517
T1,
@@ -168,37 +170,6 @@ _mcsolve_discrete_condition(u, t, integrator) =
168170

169171
##
170172

171-
#=
172-
_mc_get_save_callback
173-
174-
Return the Callback that is responsible for saving the expectation values of the system.
175-
=#
176-
function _mc_get_save_callback(sol::AbstractODESolution)
177-
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
178-
return _mc_get_save_callback(kwargs.callback) # There is always the Jump callback
179-
end
180-
_mc_get_save_callback(integrator::AbstractODEIntegrator) = _mc_get_save_callback(integrator.opts.callback)
181-
function _mc_get_save_callback(cb::CallbackSet)
182-
cbs_discrete = cb.discrete_callbacks
183-
184-
if length(cbs_discrete) > 0
185-
idx = _mcsolve_has_continuous_jump(cb) ? 1 : 2
186-
_cb = cb.discrete_callbacks[idx]
187-
return _mc_get_save_callback(_cb)
188-
else
189-
return nothing
190-
end
191-
end
192-
_mc_get_save_callback(cb::DiscreteCallback) =
193-
if cb.affect! isa SaveFuncMCSolve
194-
return cb
195-
else
196-
return nothing
197-
end
198-
_mc_get_save_callback(cb::ContinuousCallback) = nothing
199-
200-
##
201-
202173
function _mc_get_jump_callback(sol::AbstractODESolution)
203174
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
204175
return _mc_get_jump_callback(kwargs.callback) # There is always the Jump callback
@@ -216,8 +187,8 @@ _mc_get_jump_callback(cb::DiscreteCallback) = cb
216187
##
217188

218189
#=
219-
With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
220-
This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
190+
With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
191+
This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
221192
=#
222193
function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator)
223194
cb = _mc_get_jump_callback(integrator)
@@ -228,28 +199,6 @@ function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator)
228199
end
229200
end
230201

231-
#=
232-
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
233-
This callback can only be a PresetTimeCallback (DiscreteCallback).
234-
=#
235-
function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator)
236-
cb = _mc_get_save_callback(integrator)
237-
if cb isa Nothing
238-
return nothing
239-
else
240-
return cb.affect!.e_ops
241-
end
242-
end
243-
244-
function _mcsolve_get_expvals(sol::AbstractODESolution)
245-
cb = _mc_get_save_callback(sol)
246-
if cb isa Nothing
247-
return nothing
248-
else
249-
return cb.affect!.expvals
250-
end
251-
end
252-
253202
#=
254203
_mcsolve_initialize_callbacks(prob, tlist)
255204

src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414

1515
_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)
1616

17-
_get_save_callback_idx(method::SaveFuncSSESolve) = 2 # The first one is the normalization callback
17+
_get_save_callback_idx(cb, ::Type{SaveFuncSSESolve}) = 1 # The first one is the normalization callback
1818

1919
##
2020

src/time_evolution/mcsolve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,10 @@ function mcsolve(
398398

399399
dims = ens_prob_mc.dimensions
400400
_sol_1 = sol[:, 1]
401-
_expvals_sol_1 = _mcsolve_get_expvals(_sol_1)
401+
_expvals_sol_1 = _get_expvals(_sol_1, SaveFuncMCSolve)
402402

403-
_expvals_all = _expvals_sol_1 isa Nothing ? nothing : map(i -> _mcsolve_get_expvals(sol[:, i]), eachindex(sol))
403+
_expvals_all =
404+
_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_expvals(sol[:, i], SaveFuncMCSolve), eachindex(sol))
404405
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
405406
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
406407
col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))

src/time_evolution/time_evolution_dynamical.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ function _DSF_mcsolve_Affect!(integrator)
522522
# e_ops0 = params.e_ops
523523
# c_ops0 = params.c_ops
524524

525-
e_ops0 = _mcsolve_get_e_ops(integrator)
525+
e_ops0 = _get_e_ops(integrator, SaveFuncMCSEsolve)
526526
c_ops0, c_ops0_herm = _mcsolve_get_c_ops(integrator)
527527

528528
copyto!(ψt, integrator.u)

test/runtests.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,23 @@ const testdir = dirname(@__FILE__)
77

88
# Put core tests in alphabetical order
99
core_tests = [
10-
# "block_diagonal_form.jl",
11-
# "correlations_and_spectrum.jl",
12-
# "dynamical_fock_dimension_mesolve.jl",
13-
# "dynamical-shifted-fock.jl",
14-
# "eigenvalues_and_operators.jl",
15-
# "entanglement.jl",
16-
# "generalized_master_equation.jl",
17-
# "low_rank_dynamics.jl",
18-
# "negativity_and_partial_transpose.jl",
19-
# "progress_bar.jl",
20-
# "quantum_objects.jl",
21-
# "quantum_objects_evo.jl",
22-
# "states_and_operators.jl",
23-
# "steady_state.jl",
10+
"block_diagonal_form.jl",
11+
"correlations_and_spectrum.jl",
12+
"dynamical_fock_dimension_mesolve.jl",
13+
"dynamical-shifted-fock.jl",
14+
"eigenvalues_and_operators.jl",
15+
"entanglement.jl",
16+
"generalized_master_equation.jl",
17+
"low_rank_dynamics.jl",
18+
"negativity_and_partial_transpose.jl",
19+
"progress_bar.jl",
20+
"quantum_objects.jl",
21+
"quantum_objects_evo.jl",
22+
"states_and_operators.jl",
23+
"steady_state.jl",
2424
"time_evolution.jl",
25-
# "utilities.jl",
26-
# "wigner.jl",
25+
"utilities.jl",
26+
"wigner.jl",
2727
]
2828

2929
if (GROUP == "All") || (GROUP == "Core")

0 commit comments

Comments
 (0)