Skip to content

Commit e8a550e

Browse files
authored
Improve accuracy of ODE solvers for general cases (#586)
1 parent c0270e7 commit e8a550e

File tree

16 files changed

+159
-93
lines changed

16 files changed

+159
-93
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
- Change default solver detection in `eigensolve` when using `sigma` keyword argument (shift-inverse algorithm). If the operator is a `SparseMatrixCSC`, the default solver is `UMFPACKFactorization`, otherwise it is automatically chosen by LinearSolve.jl, depending on the type of the operator. ([#580])
1111
- Add keyword argument `assume_hermitian` to `liouvillian`. This allows users to disable the assumption that the Hamiltonian is Hermitian. ([#581])
12+
- Improve accuracy of ODE solvers for general cases. ([#586])
1213
- Use LinearSolve's internal methods for preconditioners in `SteadyStateLinearSolver`. ([#588])
1314
- Use `FillArrays.jl` for handling superoperators. This makes the code cleaner and potentially more efficient. ([#589])
1415
- Make sure state generating functions return dense array by default. ([#591])
@@ -365,6 +366,7 @@ Release date: 2024-11-13
365366
[#579]: https://github.com/qutip/QuantumToolbox.jl/issues/579
366367
[#580]: https://github.com/qutip/QuantumToolbox.jl/issues/580
367368
[#581]: https://github.com/qutip/QuantumToolbox.jl/issues/581
369+
[#586]: https://github.com/qutip/QuantumToolbox.jl/issues/586
368370
[#588]: https://github.com/qutip/QuantumToolbox.jl/issues/588
369371
[#589]: https://github.com/qutip/QuantumToolbox.jl/issues/589
370372
[#591]: https://github.com/qutip/QuantumToolbox.jl/issues/591

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ authors = ["Alberto Mercurio", "Yi-Te Huang"]
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
98
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
109
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
1110
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -17,7 +16,8 @@ LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
1716
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1817
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1918
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
20-
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
19+
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
20+
OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
2121
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2222
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
2323
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -46,7 +46,6 @@ QuantumToolboxMakieExt = "Makie"
4646
ArrayInterface = "6, 7"
4747
CUDA = "5.0 - 5.8, 5.9.4 - 5"
4848
ChainRulesCore = "1"
49-
DiffEqBase = "6"
5049
DiffEqCallbacks = "4.2.1 - 4"
5150
DiffEqNoiseProcess = "5"
5251
Distributed = "1"
@@ -61,7 +60,8 @@ LinearAlgebra = "1"
6160
LinearSolve = "2, 3"
6261
Makie = "0.24"
6362
OrdinaryDiffEqCore = "1"
64-
OrdinaryDiffEqTsit5 = "1"
63+
OrdinaryDiffEqLowOrderRK = "1"
64+
OrdinaryDiffEqVerner = "1"
6565
Pkg = "1"
6666
ProgressMeter = "1.11.0"
6767
Random = "1"

src/QuantumToolbox.jl

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
module QuantumToolbox
22

3-
# Standard Julia libraries
3+
## Standard Julia libraries
44
using LinearAlgebra
5-
import LinearAlgebra: checksquare
65
using SparseArrays
6+
7+
import Distributed: RemoteChannel
8+
import LinearAlgebra: checksquare
9+
import Pkg
10+
import Random: AbstractRNG, default_rng, seed!
711
import Statistics: mean, std
812

9-
# SciML packages (for QobjEvo, OrdinaryDiffEq, and LinearSolve)
13+
## SciML packages (for QobjEvo, OrdinaryDiffEq, and LinearSolve)
1014
import SciMLBase:
1115
solve,
1216
solve!,
@@ -32,8 +36,10 @@ import SciMLBase:
3236
DiscreteCallback,
3337
AbstractSciMLProblem,
3438
AbstractODEIntegrator,
35-
AbstractODESolution
36-
import StochasticDiffEq: StochasticDiffEqAlgorithm, SRA2, SRIW1
39+
AbstractODEAlgorithm,
40+
AbstractODESolution,
41+
AbstractSDEAlgorithm
42+
import StochasticDiffEq: SRA2, SRIW1
3743
import SciMLOperators:
3844
cache_operator,
3945
iscached,
@@ -49,44 +55,42 @@ import SciMLOperators:
4955
concretize
5056
import LinearSolve:
5157
SciMLLinearSolveAlgorithm, KrylovJL_MINRES, KrylovJL_GMRES, UMFPACKFactorization, OperatorAssumptions
52-
import DiffEqBase: get_tstops
5358
import DiffEqCallbacks: PeriodicCallback, FunctionCallingCallback, FunctionCallingAffect, TerminateSteadyState
54-
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm
55-
import OrdinaryDiffEqTsit5: Tsit5
59+
import OrdinaryDiffEqVerner: Vern7
60+
import OrdinaryDiffEqLowOrderRK: DP5
5661
import DiffEqNoiseProcess: RealWienerProcess!, RealWienerProcess
5762

58-
# other dependencies (in alphabetical order)
63+
## other dependencies (in alphabetical order)
5964
import ArrayInterface: allowed_getindex, allowed_setindex!
60-
import Distributed: RemoteChannel
6165
import FFTW: fft, ifft, fftfreq, fftshift
6266
import FillArrays: Eye
6367
import Graphs: connected_components, DiGraph
6468
import IncompleteLU: ilu
6569
import LaTeXStrings: @L_str
66-
import Pkg
6770
import ProgressMeter: Progress, next!
68-
import Random: AbstractRNG, default_rng, seed!
6971
import SpecialFunctions: loggamma
7072
import StaticArraysCore: SVector, MVector
7173

7274
# Export functions from the other modules
7375

74-
# LinearAlgebra
76+
## LinearAlgebra
7577
export ishermitian, issymmetric, isposdef, dot, tr, svdvals, norm, normalize, normalize!, diag, Hermitian, Symmetric
7678

77-
# SparseArrays
79+
## SparseArrays
7880
export permute
7981

80-
# SciMLOperators
82+
## SciMLOperators
8183
export cache_operator, iscached, isconstant
8284

83-
# Utility
85+
# Source files
86+
87+
## Utility
8488
include("settings.jl")
8589
include("utilities.jl")
8690
include("versioninfo.jl")
8791
include("linear_maps.jl")
8892

89-
# Quantum Object
93+
## Quantum Object
9094
include("qobj/space.jl")
9195
include("qobj/energy_restricted.jl")
9296
include("qobj/dimensions.jl")
@@ -103,7 +107,7 @@ include("qobj/superoperators.jl")
103107
include("qobj/synonyms.jl")
104108
include("qobj/block_diagonal_form.jl")
105109

106-
# time evolution
110+
## time evolution
107111
include("time_evolution/time_evolution.jl")
108112
include("time_evolution/callback_helpers/callback_helpers.jl")
109113
include("time_evolution/callback_helpers/sesolve_callback_helpers.jl")
@@ -120,7 +124,7 @@ include("time_evolution/ssesolve.jl")
120124
include("time_evolution/smesolve.jl")
121125
include("time_evolution/time_evolution_dynamical.jl")
122126

123-
# Others
127+
## Others
124128
include("correlations.jl")
125129
include("wigner.jl")
126130
include("spin_lattice.jl")
@@ -132,7 +136,7 @@ include("steadystate.jl")
132136
include("spectrum.jl")
133137
include("visualization.jl")
134138

135-
# deprecated functions
139+
## deprecated functions
136140
include("deprecated.jl")
137141

138142
end

src/qobj/eigsolve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ end
422422
H::Union{AbstractQuantumObject{HOpType},Tuple},
423423
T::Real,
424424
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
425-
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
425+
alg::AbstractODEAlgorithm = DP5(),
426426
params::NamedTuple = NamedTuple(),
427427
ρ0::AbstractMatrix = rand_dm(prod(H.dimensions)).data,
428428
eigvals::Int = 1,
@@ -440,7 +440,7 @@ Solve the eigenvalue problem for a Liouvillian superoperator `L` using the Arnol
440440
- `H`: The Hamiltonian (or directly the Liouvillian) of the system. It can be a [`QuantumObject`](@ref), a [`QuantumObjectEvolution`](@ref), or a tuple of the form supported by [`mesolve`](@ref).
441441
- `T`: The time at which to evaluate the time evolution.
442442
- `c_ops`: A vector of collapse operators. Default is `nothing` meaning the system is closed.
443-
- `alg`: The differential equation solver algorithm. Default is `Tsit5()`.
443+
- `alg`: The differential equation solver algorithm. Default is `DP5()`.
444444
- `params`: A `NamedTuple` containing the parameters of the system.
445445
- `ρ0`: The initial density matrix. If not specified, a random density matrix is used.
446446
- `eigvals`: The number of eigenvalues to compute.
@@ -465,7 +465,7 @@ function eigsolve_al(
465465
H::Union{AbstractQuantumObject{HOpType},Tuple},
466466
T::Real,
467467
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
468-
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
468+
alg::AbstractODEAlgorithm = DP5(),
469469
params::NamedTuple = NamedTuple(),
470470
ρ0::AbstractMatrix = rand_dm(prod(H.dimensions)).data,
471471
eigvals::Int = 1,

src/steadystate.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ end
4141

4242
@doc raw"""
4343
SteadyStateODESolver(
44-
alg = Tsit5(),
44+
alg = DP5(),
4545
ψ0 = nothing,
4646
tmax = Inf,
4747
terminate_reltol = 1e-4,
@@ -63,7 +63,7 @@ or
6363
```
6464
6565
# Arguments
66-
- `alg::OrdinaryDiffEqAlgorithm=Tsit5()`: The algorithm to solve the ODE.
66+
- `alg::AbstractODEAlgorithm=DP5()`: The algorithm to solve the ODE.
6767
- `ψ0::Union{Nothing,QuantumObject}=nothing`: The initial state of the system. If not specified, a random pure state will be generated.
6868
- `tmax::Real=Inf`: The final time step for the steady state problem.
6969
- `terminate_reltol` = The relative tolerance for stationary state terminate condition. Default to `1e-4`.
@@ -75,13 +75,13 @@ or
7575
For more details about the solving `alg`orithms, please refer to [`OrdinaryDiffEq.jl`](https://docs.sciml.ai/OrdinaryDiffEq/stable/).
7676
"""
7777
Base.@kwdef struct SteadyStateODESolver{
78-
MT<:OrdinaryDiffEqAlgorithm,
78+
MT<:AbstractODEAlgorithm,
7979
ST<:Union{Nothing,QuantumObject},
8080
TT<:Real,
8181
RT<:Real,
8282
AT<:Real,
8383
} <: SteadyStateSolver
84-
alg::MT = Tsit5()
84+
alg::MT = DP5()
8585
ψ0::ST = nothing
8686
tmax::TT = Inf
8787
terminate_reltol::RT = 1e-4

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ This file contains helper functions for callbacks. The affect! function are defi
66

77
abstract type AbstractSaveFunc end
88

9+
function _merge_tstops(kwargs, prob_is_const::Bool, tlist)
10+
if prob_is_const
11+
return kwargs
12+
else
13+
tstops = haskey(kwargs, :tstops) ? unique!(sort!(vcat(tlist, kwargs.tstops))) : tlist
14+
return merge(kwargs, (tstops = tstops,))
15+
end
16+
end
17+
918
# Multiple dispatch depending on the progress_bar and e_ops types
1019
function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method)
1120
cb = _generate_save_callback(e_ops, tlist, progress_bar, method)
@@ -26,8 +35,7 @@ function _generate_stochastic_kwargs(
2635

2736
# Ensure that the noise is stored in tlist. # TODO: Fix this directly in DiffEqNoiseProcess.jl
2837
# See https://github.com/SciML/DiffEqNoiseProcess.jl/issues/214 for example
29-
tstops = haskey(kwargs, :tstops) ? unique!(sort!(vcat(tlist, kwargs.tstops))) : tlist
30-
kwargs2 = merge(kwargs, (tstops = tstops,))
38+
kwargs2 = _merge_tstops(kwargs, false, tlist) # set 'prob_is_const = false' to force add 'tstops = tlist'
3139

3240
if SF === SaveFuncSSESolve
3341
cb_normalize = _ssesolve_generate_normalize_cb()

src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,15 @@ function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rn
107107

108108
if e_ops isa Nothing
109109
# We are implicitly saying that we don't have a `Progress`
110-
kwargs2 =
111-
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, kwargs.callback),)) :
112-
merge(kwargs, (callback = cb1,))
113-
return kwargs2
110+
kwargs2 = _merge_kwargs_with_callback(kwargs, cb1)
114111
else
115112
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
116113

117114
_save_func = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
118115
cb2 = FunctionCallingCallback(_save_func, funcat = tlist)
119-
kwargs2 =
120-
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
121-
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
122-
return kwargs2
116+
kwargs2 = _merge_kwargs_with_callback(kwargs, CallbackSet(cb1, cb2))
123117
end
118+
return kwargs2
124119
end
125120

126121
function _lindblad_jump_affect!(

src/time_evolution/lr_mesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct TimeEvolutionLRSol{
2525
TS<:AbstractVector,
2626
TE<:Matrix{ComplexF64},
2727
RetT<:Enum,
28-
AlgT<:OrdinaryDiffEqAlgorithm,
28+
AlgT<:AbstractODEAlgorithm,
2929
TolT<:Real,
3030
TSZB<:AbstractVector,
3131
TM<:Vector{<:Integer},
@@ -45,7 +45,7 @@ struct TimeEvolutionLRSol{
4545
end
4646

4747
lr_mesolve_options_default = (
48-
alg = Tsit5(),
48+
alg = DP5(),
4949
progress = true,
5050
err_max = 0.0,
5151
p0 = 0.0,

src/time_evolution/mcsolve.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ end
274274
ψ0::QuantumObject{Ket},
275275
tlist::AbstractVector,
276276
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
277-
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
277+
alg::AbstractODEAlgorithm = DP5(),
278278
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
279279
params = NullParameters(),
280280
rng::AbstractRNG = default_rng(),
@@ -329,7 +329,7 @@ If the environmental measurements register a quantum jump, the wave function und
329329
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``.
330330
- `tlist`: List of time points at which to save either the state or the expectation values of the system.
331331
- `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`.
332-
- `alg`: The algorithm to use for the ODE solver. Default to `Tsit5()`.
332+
- `alg`: The algorithm to use for the ODE solver. Default to `DP5()`.
333333
- `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`.
334334
- `params`: Parameters to pass to the solver. This argument is usually expressed as a `NamedTuple` or `AbstractVector` of parameters. For more advanced usage, any custom struct can be used.
335335
- `rng`: Random number generator for reproducibility.
@@ -361,7 +361,7 @@ function mcsolve(
361361
ψ0::QuantumObject{Ket},
362362
tlist::AbstractVector,
363363
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
364-
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
364+
alg::AbstractODEAlgorithm = DP5(),
365365
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
366366
params = NullParameters(),
367367
rng::AbstractRNG = default_rng(),
@@ -398,7 +398,7 @@ end
398398

399399
function mcsolve(
400400
ens_prob_mc::TimeEvolutionProblem,
401-
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
401+
alg::AbstractODEAlgorithm = DP5(),
402402
ntraj::Int = 500,
403403
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
404404
keep_runs_results = Val(false),

0 commit comments

Comments
 (0)