Skip to content

Commit dddb60b

Browse files
Remove expvals from TimeEvolutionParameters
1 parent b652fa7 commit dddb60b

File tree

9 files changed

+170
-119
lines changed

9 files changed

+170
-119
lines changed

src/QuantumToolbox.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ import SciMLBase:
3636
ContinuousCallback,
3737
DiscreteCallback,
3838
AbstractSciMLProblem,
39-
AbstractODEIntegrator
39+
AbstractODEIntegrator,
40+
AbstractODESolution
4041
import StochasticDiffEq: StochasticDiffEqAlgorithm, SRA1
4142
import SciMLOperators:
4243
SciMLOperators,

src/time_evolution/callback_helpers.jl

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ This file contains helper functions for callbacks. The affect! function are defi
44

55
########## SESOLVE ##########
66

7-
struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT}
7+
struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
88
e_ops::TE
99
progr::PT
1010
iter::IT
11+
expvals::TEXPV
1112
end
1213

13-
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter)
14+
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
1415
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func_sesolve(integrator, f.progr)
1516

1617
##
@@ -29,8 +30,7 @@ function _save_func_sesolve(integrator, progr::Nothing)
2930
end
3031

3132
# When e_ops is a list of operators
32-
function _save_func_sesolve(integrator, e_ops, progr, iter)
33-
expvals = integrator.p.expvals
33+
function _save_func_sesolve(integrator, e_ops, progr, iter, expvals)
3434
ψ = integrator.u
3535
_expect = op -> dot(ψ, op, ψ)
3636
@. expvals[:, iter[]] = _expect(e_ops)
@@ -44,18 +44,40 @@ function _generate_sesolve_callback(e_ops, tlist, progress_bar)
4444

4545
progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing
4646

47-
_save_affect! = SaveFuncSESolve(e_ops_data, progr, Ref(1))
48-
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
47+
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
48+
49+
_save_affect! = SaveFuncSESolve(e_ops_data, progr, Ref(1), expvals)
50+
return _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
4951
end
5052

53+
function _sesolve_get_expvals(sol::AbstractODESolution)
54+
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
55+
if hasproperty(kwargs, :callback)
56+
return _sesolve_get_expvals(kwargs.callback)
57+
else
58+
return nothing
59+
end
60+
end
61+
function _sesolve_get_expvals(cb::CallbackSet)
62+
_cb = cb.discrete_callbacks[1]
63+
return _sesolve_get_expvals(_cb)
64+
end
65+
_sesolve_get_expvals(cb::DiscreteCallback) = if cb.affect! isa SaveFuncSESolve
66+
return cb.affect!.expvals
67+
else
68+
return nothing
69+
end
70+
_sesolve_get_expvals(cb::ContinuousCallback) = nothing
71+
5172
########## MCSOLVE ##########
5273

53-
struct SaveFuncMCSolve{TE,IT}
74+
struct SaveFuncMCSolve{TE,IT,TEXPV}
5475
e_ops::TE
5576
iter::IT
77+
expvals::TEXPV
5678
end
5779

58-
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter)
80+
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals)
5981

6082
struct LindbladJump{T1,T2}
6183
c_ops::T1
@@ -66,8 +88,7 @@ end
6688

6789
##
6890

69-
function _save_func_mcsolve(integrator, e_ops, iter)
70-
expvals = integrator.p.expvals
91+
function _save_func_mcsolve(integrator, e_ops, iter, expvals)
7192
cache_mc = integrator.p.mcsolve_params.cache_mc
7293

7394
copyto!(cache_mc, integrator.u)
@@ -100,12 +121,15 @@ function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
100121
end
101122

102123
if e_ops isa Nothing
124+
# We are implicitly saying that we don't have a `ProgressBar`
103125
kwargs2 =
104126
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, kwargs.callback),)) :
105127
merge(kwargs, (callback = cb1,))
106128
return kwargs2
107129
else
108-
_save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1))
130+
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
131+
132+
_save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
109133
cb2 = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
110134
kwargs2 =
111135
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
@@ -178,35 +202,67 @@ function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator)
178202
return cb.affect!.e_ops
179203
end
180204

205+
function _mcsolve_get_expvals(sol::AbstractODESolution)
206+
cb = NamedTuple(sol.prob.kwargs).callback
207+
if _mcsolve_has_discrete_callbacks(cb)
208+
return _mcsolve_get_expvals(cb)
209+
else
210+
return nothing
211+
end
212+
end
213+
function _mcsolve_get_expvals(cb::CallbackSet)
214+
idx = _mcsolve_has_continuous_jump(cb) ? 1 : 2
215+
_cb = cb.discrete_callbacks[idx]
216+
return _mcsolve_get_expvals(_cb)
217+
end
218+
_mcsolve_get_expvals(cb::DiscreteCallback) =
219+
if cb.affect! isa SaveFuncMCSolve
220+
return cb.affect!.expvals
221+
else
222+
nothing
223+
end
224+
_mcsolve_get_expvals(cb::ContinuousCallback) = nothing
225+
181226
#=
182-
_mcsolve_callbacks_new_iter(prob, tlist)
227+
_mcsolve_callbacks_new_iter_expvals(prob, tlist)
183228
184-
Return the same callbacks of the `prob`, but with the `iter` variable reinitialized to 1.
229+
Return the same callbacks of the `prob`, but with the `iter` variable reinitialized to 1 and the `expvals` variable reinitialized to a new matrix.
185230
=#
186-
function _mcsolve_callbacks_new_iter(prob, tlist)
231+
function _mcsolve_callbacks_new_iter_expvals(prob, tlist)
187232
cb = prob.kwargs[:callback]
188-
return _mcsolve_callbacks_new_iter(cb, tlist)
233+
return _mcsolve_callbacks_new_iter_expvals(cb, tlist)
189234
end
190-
function _mcsolve_callbacks_new_iter(cb::CallbackSet, tlist)
235+
function _mcsolve_callbacks_new_iter_expvals(cb::CallbackSet, tlist)
191236
cb_continuous = cb.continuous_callbacks
192237
cb_discrete = cb.discrete_callbacks
193238

194-
if length(cb_continuous) > 0
239+
if _mcsolve_has_continuous_jump(cb)
195240
idx = 1
196241
e_ops = cb_discrete[idx].affect!.e_ops
197-
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1))
242+
expvals = similar(cb_discrete[idx].affect!.expvals)
243+
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
198244
cb_save = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
199245
return CallbackSet(cb_continuous..., cb_save, cb_discrete[2:end]...)
200246
else
201247
idx = 2
202248
e_ops = cb_discrete[idx].affect!.e_ops
203-
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1))
249+
expvals = similar(cb_discrete[idx].affect!.expvals)
250+
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
204251
cb_save = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
205252
return CallbackSet(cb_continuous..., cb_discrete[1], cb_save, cb_discrete[3:end]...)
206253
end
207254
end
208-
_mcsolve_callbacks_new_iter(cb::ContinuousCallback, tlist) = cb
209-
_mcsolve_callbacks_new_iter(cb::DiscreteCallback, tlist) = cb
255+
_mcsolve_callbacks_new_iter_expvals(cb::ContinuousCallback, tlist) = cb # It is only the continuous LindbladJump callback
256+
_mcsolve_callbacks_new_iter_expvals(cb::DiscreteCallback, tlist) = cb # It is only the discrete LindbladJump callback
257+
258+
_mcsolve_has_discrete_callbacks(cb::CallbackSet) = length(cb.discrete_callbacks) > 0
259+
_mcsolve_has_discrete_callbacks(cb::ContinuousCallback) = false
260+
_mcsolve_has_discrete_callbacks(cb::DiscreteCallback) = true
261+
262+
_mcsolve_has_continuous_jump(cb::CallbackSet) =
263+
(length(cb.continuous_callbacks) > 0) && (cb.continuous_callbacks[1].affect! isa LindbladJump)
264+
_mcsolve_has_continuous_jump(cb::ContinuousCallback) = true
265+
_mcsolve_has_continuous_jump(cb::DiscreteCallback) = false
210266

211267
## Temporary function to avoid errors. Waiting for the PR In DiffEqCallbacks.jl to be merged.
212268

src/time_evolution/mcsolve.jl

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ function _mcsolve_prob_func(prob, i, repeat, global_rng, seeds, tlist)
88
traj_rng = typeof(global_rng)()
99
seed!(traj_rng, seed)
1010

11-
expvals = similar(params.expvals)
12-
13-
T = eltype(expvals)
11+
T = eltype(prob.u0)
1412

1513
mcsolve_params = (
1614
traj_rng = traj_rng,
@@ -23,10 +21,10 @@ function _mcsolve_prob_func(prob, i, repeat, global_rng, seeds, tlist)
2321
jump_times_which_idx = T[1],
2422
)
2523

26-
p = TimeEvolutionParameters(params.params, expvals, mcsolve_params)
24+
p = TimeEvolutionParameters(params.params, mcsolve_params)
2725

2826
f = deepcopy(prob.f.f)
29-
cb = _mcsolve_callbacks_new_iter(prob, tlist)
27+
cb = _mcsolve_callbacks_new_iter_expvals(prob, tlist)
3028

3129
return remake(prob, f = f, p = p, callback = cb)
3230
end
@@ -106,7 +104,7 @@ end
106104
tlist::AbstractVector,
107105
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
108106
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
109-
params::Union{NamedTuple,AbstractVector} = eltype(ψ0)[],
107+
params = eltype(ψ0)[],
110108
rng::AbstractRNG = default_rng(),
111109
jump_callback::TJC = ContinuousLindbladJumpCallback(),
112110
kwargs...,
@@ -153,7 +151,7 @@ If the environmental measurements register a quantum jump, the wave function und
153151
- `tlist`: List of times at which to save either the state or the expectation values of the system.
154152
- `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`.
155153
- `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`.
156-
- `params`: `NamedTuple` or `AbstractVector` of parameters to pass to the solver.
154+
- `params`: `NamedTuple` or `AbstractVector` of parameters to pass to the solver. For more advanced usage, any custom struct can be used.
157155
- `rng`: Random number generator for reproducibility.
158156
- `jump_callback`: The Jump Callback type: Discrete or Continuous. The default is `ContinuousLindbladJumpCallback()`, which is more precise.
159157
- `kwargs`: The keyword arguments for the ODEProblem.
@@ -175,7 +173,7 @@ function mcsolveProblem(
175173
tlist::AbstractVector,
176174
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
177175
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
178-
params::Union{NamedTuple,AbstractVector} = eltype(ψ0)[],
176+
params = eltype(ψ0)[],
179177
rng::AbstractRNG = default_rng(),
180178
jump_callback::TJC = ContinuousLindbladJumpCallback(),
181179
kwargs...,
@@ -192,13 +190,7 @@ function mcsolveProblem(
192190

193191
T = Base.promote_eltype(H_eff_evo, ψ0)
194192

195-
if e_ops isa Nothing
196-
expvals = Array{T}(undef, 0, length(tlist))
197-
is_empty_e_ops = true
198-
else
199-
expvals = Array{T}(undef, length(e_ops), length(tlist))
200-
is_empty_e_ops = isempty(e_ops)
201-
end
193+
is_empty_e_ops = e_ops isa Nothing ? true : isempty(e_ops)
202194

203195
saveat = is_empty_e_ops ? tlist : [tlist[end]]
204196
# We disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
@@ -217,8 +209,6 @@ function mcsolveProblem(
217209
random_n = similar(ψ0.data, T, 1) # We could use a Ref, but we have to keep the same type for all the parameters due to SciMLStructures.jl.
218210
random_n[1] = rand(rng)
219211

220-
progr = ProgressBar(length(tlist), enable = false)
221-
222212
mcsolve_params = (
223213
traj_rng = rng,
224214
random_n = random_n,
@@ -229,7 +219,7 @@ function mcsolveProblem(
229219
jump_which = jump_which,
230220
jump_times_which_idx = jump_times_which_idx,
231221
)
232-
p = TimeEvolutionParameters(params, expvals, mcsolve_params)
222+
p = TimeEvolutionParameters(params, mcsolve_params)
233223

234224
return sesolveProblem(H_eff_evo, ψ0, tlist; params = p, kwargs3...)
235225
end
@@ -524,13 +514,15 @@ function mcsolve(
524514

525515
dims = ens_prob_mc.dims
526516
_sol_1 = sol[:, 1]
517+
_expvals_sol_1 = _mcsolve_get_expvals(_sol_1)
527518

528-
expvals_all = mapreduce(i -> sol[:, i].prob.p.expvals, (x, y) -> cat(x, y, dims = 3), eachindex(sol))
519+
_expvals_all = _expvals_sol_1 isa Nothing ? nothing : map(i -> _mcsolve_get_expvals(sol[:, i]), eachindex(sol))
520+
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all)
529521
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
530522
jump_times = map(i -> real.(sol[:, i].prob.p.mcsolve_params.jump_times), eachindex(sol))
531523
jump_which = map(i -> round.(Int, sol[:, i].prob.p.mcsolve_params.jump_which), eachindex(sol))
532524

533-
expvals = dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
525+
expvals = _expvals_sol_1 isa Nothing ? nothing : dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
534526

535527
return TimeEvolutionMCSol(
536528
ntraj,

src/time_evolution/sesolve.jl

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ export sesolveProblem, sesolve
22

33
function _merge_sesolve_kwargs_with_callback(kwargs, cb)
44
kwargs2 =
5-
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb),)) :
5+
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) :
66
merge(kwargs, (callback = cb,))
77

88
return kwargs2
@@ -25,7 +25,7 @@ _sesolve_make_U_QobjEvo(H) = QobjEvo(H, -1im)
2525
ψ0::QuantumObject{DT2,KetQuantumObject},
2626
tlist::AbstractVector;
2727
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
28-
params::Union{NamedTuple, AbstractVector, TimeEvolutionParameters} = eltype(ψ0)[],
28+
params = eltype(ψ0)[],
2929
progress_bar::Union{Val,Bool} = Val(true),
3030
inplace::Union{Val,Bool} = Val(true),
3131
kwargs...,
@@ -43,7 +43,7 @@ Generate the ODEProblem for the Schrödinger time evolution of a quantum system:
4343
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``.
4444
- `tlist`: List of times at which to save either the state or the expectation values of the system.
4545
- `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`.
46-
- `params`: `NamedTuple` or `AbstractVector` of parameters to pass to the solver. For more advanced usage, you can use the [`TimeEvolutionParameters`](@ref) struct.
46+
- `params`: `NamedTuple` or `AbstractVector` of parameters to pass to the solver. For more advanced usage, any custom struct can be used.
4747
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
4848
- `inplace`: Whether to use the inplace version of the ODEProblem. The default is `Val(true)`.
4949
- `kwargs`: The keyword arguments for the ODEProblem.
@@ -64,7 +64,7 @@ function sesolveProblem(
6464
ψ0::QuantumObject{DT2,KetQuantumObject},
6565
tlist::AbstractVector;
6666
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
67-
params::Union{NamedTuple,AbstractVector,TimeEvolutionParameters} = eltype(ψ0)[],
67+
params = eltype(ψ0)[],
6868
progress_bar::Union{Val,Bool} = Val(true),
6969
inplace::Union{Val,Bool} = Val(true),
7070
kwargs...,
@@ -78,34 +78,19 @@ function sesolveProblem(
7878
isoper(H_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
7979
check_dims(H_evo, ψ0)
8080

81-
ψ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type
81+
T = Base.promote_eltype(H_evo, ψ0)
82+
ψ0 = sparse_to_dense(_CType(T), get_data(ψ0)) # Convert it to dense vector with complex element type
8283
U = H_evo.data
8384

84-
if e_ops isa Nothing
85-
expvals = Array{ComplexF64}(undef, 0, length(tlist))
86-
is_empty_e_ops = true
87-
else
88-
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
89-
is_empty_e_ops = isempty(e_ops)
90-
end
91-
92-
if params isa TimeEvolutionParameters
93-
(e_ops isa Nothing) || throw(
94-
ArgumentError(
95-
"The parameter `params` cannot be a TimeEvolutionParameters object when `e_ops` is not Nothing",
96-
),
97-
)
98-
end
99-
100-
p = params isa TimeEvolutionParameters ? params : TimeEvolutionParameters(params, expvals)
85+
is_empty_e_ops = (e_ops isa Nothing) ? true : isempty(e_ops)
10186

10287
saveat = is_empty_e_ops ? tlist : [tlist[end]]
10388
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
10489
kwargs2 = merge(default_values, kwargs)
10590
kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2)
10691

10792
tspan = (tlist[1], tlist[end])
108-
prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, p; kwargs3...)
93+
prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs3...)
10994

11095
return TimeEvolutionProblem(prob, tlist, H_evo.dims)
11196
end
@@ -117,7 +102,7 @@ end
117102
tlist::AbstractVector;
118103
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
119104
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
120-
params::Union{NamedTuple, AbstractVector} = eltype(ψ0)[],
105+
params = eltype(ψ0)[],
121106
progress_bar::Union{Val,Bool} = Val(true),
122107
inplace::Union{Val,Bool} = Val(true),
123108
kwargs...,
@@ -136,7 +121,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
136121
- `tlist`: List of times at which to save either the state or the expectation values of the system.
137122
- `alg`: The algorithm for the ODE solver. The default is `Tsit5()`.
138123
- `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`.
139-
- `params`: `NamedTuple` or `AbstractVector` of parameters to pass to the solver.
124+
- `params`: `NamedTuple` or `AbstractVector` of parameters to pass to the solver. For more advanced usage, any custom struct can be used.
140125
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
141126
- `inplace`: Whether to use the inplace version of the ODEProblem. The default is `Val(true)`.
142127
- `kwargs`: The keyword arguments for the ODEProblem.
@@ -159,7 +144,7 @@ function sesolve(
159144
tlist::AbstractVector;
160145
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
161146
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
162-
params::Union{NamedTuple,AbstractVector} = eltype(ψ0)[],
147+
params = eltype(ψ0)[],
163148
progress_bar::Union{Val,Bool} = Val(true),
164149
inplace::Union{Val,Bool} = Val(true),
165150
kwargs...,
@@ -186,7 +171,7 @@ function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit
186171
return TimeEvolutionSol(
187172
prob.times,
188173
ψt,
189-
sol.prob.p.expvals,
174+
_sesolve_get_expvals(sol),
190175
sol.retcode,
191176
sol.alg,
192177
NamedTuple(sol.prob.kwargs).abstol,

0 commit comments

Comments
 (0)