Skip to content

Commit 2836696

Browse files
Make time evolution solvers compatible with automatic differentiation (#311)
* Working sesolve * add `inplace` keywork argument * add SciMLStructures and relax params type * Working mcsolve (no type-stability) * Fix type-instabilities for mcsolve * Add SciMLStructures.jl methods * Add callbacks helpers * Fix dsf_mcsolve * Remove ProgressBar from ODE parameters * Fix abstol and reltol extraction * Use Base allequal function * Remove expvals from TimeEvolutionParameters * Make NullParameters as default for params * Remove custom PresetTimeCallback * Update description of `inplace` argument * Working mesolve * Fix dfd_mesolve and dsf_mesolve * Remove TimeEvolutionParameters (type-unstable) * Fix type instabilities * Fix type instabilities on Julia v1.10 * Format document
1 parent 9567c45 commit 2836696

17 files changed

+876
-561
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10-
- *__We will start to write changelog once we have the first standard release.__*
10+
- Change the parameters structure of `sesolve`, `mesolve` and `mcsolve` functions to possibly support automatic differentiation. ([#311])
11+
1112

1213
## [v0.21.5] (2024-11-15)
1314

@@ -21,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2122
<!-- Links generated by Changelog.jl -->
2223

2324
[v0.21.4]: https://github.com/qutip/QuantumToolbox.jl/releases/tag/v0.21.4
25+
[v0.21.5]: https://github.com/qutip/QuantumToolbox.jl/releases/tag/v0.21.5
2426
[#139]: https://github.com/qutip/QuantumToolbox.jl/issues/139
2527
[#306]: https://github.com/qutip/QuantumToolbox.jl/issues/306
2628
[#309]: https://github.com/qutip/QuantumToolbox.jl/issues/309
29+
[#311]: https://github.com/qutip/QuantumToolbox.jl/issues/311

docs/src/resources/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ qeye
181181
## [Time evolution](@id doc-API:Time-evolution)
182182

183183
```@docs
184+
TimeEvolutionProblem
184185
TimeEvolutionSol
185186
TimeEvolutionMCSol
186187
TimeEvolutionSSESol

src/QuantumToolbox.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,22 @@ import SciMLBase:
2323
reinit!,
2424
remake,
2525
u_modified!,
26+
NullParameters,
2627
ODEFunction,
2728
ODEProblem,
2829
SDEProblem,
2930
EnsembleProblem,
3031
EnsembleSerial,
3132
EnsembleThreads,
33+
EnsembleSplitThreads,
3234
EnsembleDistributed,
3335
FullSpecialize,
3436
CallbackSet,
3537
ContinuousCallback,
36-
DiscreteCallback
38+
DiscreteCallback,
39+
AbstractSciMLProblem,
40+
AbstractODEIntegrator,
41+
AbstractODESolution
3742
import StochasticDiffEq: StochasticDiffEqAlgorithm, SRA1
3843
import SciMLOperators:
3944
SciMLOperators,
@@ -88,6 +93,10 @@ include("qobj/synonyms.jl")
8893

8994
# time evolution
9095
include("time_evolution/time_evolution.jl")
96+
include("time_evolution/callback_helpers/sesolve_callback_helpers.jl")
97+
include("time_evolution/callback_helpers/mesolve_callback_helpers.jl")
98+
include("time_evolution/callback_helpers/mcsolve_callback_helpers.jl")
99+
include("time_evolution/callback_helpers/callback_helpers.jl")
91100
include("time_evolution/mesolve.jl")
92101
include("time_evolution/lr_mesolve.jl")
93102
include("time_evolution/sesolve.jl")

src/correlations.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ function correlation_3op_2t(
4949
(H.dims == ψ0.dims && H.dims == A.dims && H.dims == B.dims && H.dims == C.dims) ||
5050
throw(DimensionMismatch("The quantum objects are not of the same Hilbert dimension."))
5151

52-
kwargs2 = (; kwargs...)
53-
kwargs2 = merge(kwargs2, (saveat = collect(t_l),))
52+
kwargs2 = merge((saveat = collect(t_l),), (; kwargs...))
5453
ρt = mesolve(H, ψ0, t_l, c_ops; kwargs2...).states
5554

5655
corr = map((t, ρ) -> mesolve(H, C * ρ * A, τ_l .+ t, c_ops, e_ops = [B]; kwargs...).expect[1, :], t_l, ρt)

src/qobj/eigsolve.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,15 @@ function eigsolve_al(
391391
kwargs...,
392392
) where {DT1,HOpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject}}
393393
L_evo = _mesolve_make_L_QobjEvo(H, c_ops)
394-
prob = mesolveProblem(
395-
L_evo,
396-
QuantumObject(ρ0, type = Operator, dims = H.dims),
397-
[zero(T), T];
398-
params = params,
399-
progress_bar = Val(false),
400-
kwargs...,
401-
)
394+
prob =
395+
mesolveProblem(
396+
L_evo,
397+
QuantumObject(ρ0, type = Operator, dims = H.dims),
398+
[zero(T), T];
399+
params = params,
400+
progress_bar = Val(false),
401+
kwargs...,
402+
).prob
402403
integrator = init(prob, alg)
403404

404405
# prog = ProgressUnknown(desc="Applications:", showspeed = true, enabled=progress)

src/qobj/quantum_object_evo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution`
269269
quote
270270
dims = tuple($(dims_expr...))
271271

272-
length(unique(dims)) == 1 || throw(ArgumentError("The dimensions of the operators must be the same."))
272+
allequal(dims) || throw(ArgumentError("The dimensions of the operators must be the same."))
273273

274274
data_expr_const = $qobj_expr_const isa Integer ? $qobj_expr_const : _make_SciMLOperator($qobj_expr_const, α)
275275

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#=
2+
This file contains helper functions for callbacks. The affect! function are defined taking advantage of the Julia struct, which allows to store some cache exclusively for the callback.
3+
=#
4+
5+
##
6+
7+
# Multiple dispatch depending on the progress_bar and e_ops types
8+
function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method)
9+
cb = _generate_save_callback(e_ops, tlist, progress_bar, method)
10+
return _merge_kwargs_with_callback(kwargs, cb)
11+
end
12+
_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs
13+
14+
function _merge_kwargs_with_callback(kwargs, cb)
15+
kwargs2 =
16+
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) :
17+
merge(kwargs, (callback = cb,))
18+
19+
return kwargs2
20+
end
21+
22+
function _generate_save_callback(e_ops, tlist, progress_bar, method)
23+
e_ops_data = e_ops isa Nothing ? nothing : _get_e_ops_data(e_ops, method)
24+
25+
progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing
26+
27+
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
28+
29+
_save_affect! = method(e_ops_data, progr, Ref(1), expvals)
30+
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
31+
end
32+
33+
_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)
34+
_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10
35+
36+
_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op)))
37+
38+
##
39+
40+
# When e_ops is Nothing. Common for both mesolve and sesolve
41+
function _save_func(integrator, progr)
42+
next!(progr)
43+
u_modified!(integrator, false)
44+
return nothing
45+
end
46+
47+
# When progr is Nothing. Common for both mesolve and sesolve
48+
function _save_func(integrator, progr::Nothing)
49+
u_modified!(integrator, false)
50+
return nothing
51+
end
52+
53+
##
54+
55+
# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`.
56+
function _se_me_sse_get_expvals(sol::AbstractODESolution)
57+
cb = _se_me_sse_get_save_callback(sol)
58+
if cb isa Nothing
59+
return nothing
60+
else
61+
return cb.affect!.expvals
62+
end
63+
end
64+
65+
function _se_me_sse_get_save_callback(sol::AbstractODESolution)
66+
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
67+
if hasproperty(kwargs, :callback)
68+
return _se_me_sse_get_save_callback(kwargs.callback)
69+
else
70+
return nothing
71+
end
72+
end
73+
_se_me_sse_get_save_callback(integrator::AbstractODEIntegrator) = _se_me_sse_get_save_callback(integrator.opts.callback)
74+
function _se_me_sse_get_save_callback(cb::CallbackSet)
75+
cbs_discrete = cb.discrete_callbacks
76+
if length(cbs_discrete) > 0
77+
_cb = cb.discrete_callbacks[1]
78+
return _se_me_sse_get_save_callback(_cb)
79+
else
80+
return nothing
81+
end
82+
end
83+
_se_me_sse_get_save_callback(cb::DiscreteCallback) =
84+
if (cb.affect! isa SaveFuncSESolve) || (cb.affect! isa SaveFuncMESolve)
85+
return cb
86+
else
87+
return nothing
88+
end
89+
_se_me_sse_get_save_callback(cb::ContinuousCallback) = nothing

0 commit comments

Comments
 (0)