Skip to content

Commit cc83820

Browse files
committed
Merge branch 'main' into liouvillian
2 parents c265b77 + d44d251 commit cc83820

File tree

9 files changed

+181
-47
lines changed

9 files changed

+181
-47
lines changed

.github/workflows/Benchmarks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
- uses: actions/checkout@v5
3838
- uses: julia-actions/setup-julia@v2
3939
with:
40-
version: '1'
40+
version: '1.11'
4141
arch: x64
4242
- uses: julia-actions/cache@v2
4343
- name: Run benchmark

.github/workflows/CleanPreviewDoc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
runs-on: ubuntu-latest
1616
steps:
1717
- name: Checkout gh-pages branch
18-
uses: actions/checkout@v4
18+
uses: actions/checkout@v5
1919
with:
2020
ref: gh-pages
2121
- name: Delete preview(s) and reset history

.github/workflows/SpellCheck.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ jobs:
1010
- name: Checkout Actions Repository
1111
uses: actions/checkout@v5
1212
- name: Check spelling
13-
uses: crate-ci/typos@v1.37.2
13+
uses: crate-ci/typos@v1.38.1

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main)
99

10+
- Introduce new methods of `sesolve_map` and `mesolve_map` for advanced usage. Users can now customize their own `iter`ator structure, `prob_func` and `output_func`. ([#565])
1011
- Generalize the definition of `liouvillian`. It no longer expects the Hamiltonian to be Hermitian. ([#541])
1112

1213
## [v0.37.0]
@@ -339,3 +340,4 @@ Release date: 2024-11-13
339340
[#554]: https://github.com/qutip/QuantumToolbox.jl/issues/554
340341
[#555]: https://github.com/qutip/QuantumToolbox.jl/issues/555
341342
[#557]: https://github.com/qutip/QuantumToolbox.jl/issues/557
343+
[#565]: https://github.com/qutip/QuantumToolbox.jl/issues/565

benchmarks/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
35
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"
46
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
7+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
8+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

benchmarks/autodiff.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
function benchmark_autodiff!(SUITE)
2+
# Use harmonic oscillator system for both sesolve and mesolve
3+
N = 20
4+
a = destroy(N)
5+
ψ0 = fock(N, 0)
6+
tlist = range(0, 40, 100)
7+
8+
# ---- SESOLVE ----
9+
# For direct Forward differentiation
10+
function my_f_sesolve_direct(p)
11+
H = p[1] * a' * a + p[2] * (a + a')
12+
sol = sesolve(H, ψ0, tlist, progress_bar = Val(false))
13+
return real(expect(a' * a, sol.states[end]))
14+
end
15+
16+
# For SciMLSensitivity.jl (reverse mode with Zygote and Enzyme)
17+
coef_Δ(p, t) = p[1]
18+
coef_F(p, t) = p[2]
19+
H_evo = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
20+
21+
function my_f_sesolve(p)
22+
sol = sesolve(
23+
H_evo,
24+
ψ0,
25+
tlist,
26+
progress_bar = Val(false),
27+
params = p,
28+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
29+
)
30+
return real(expect(a' * a, sol.states[end]))
31+
end
32+
33+
# ---- MESOLVE ----
34+
# For direct Forward differentiation
35+
function my_f_mesolve_direct(p)
36+
H = p[1] * a' * a + p[2] * (a + a')
37+
c_ops = [sqrt(p[3]) * a]
38+
sol = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false))
39+
return real(expect(a' * a, sol.states[end]))
40+
end
41+
42+
# For SciMLSensitivity.jl (reverse mode with Zygote and Enzyme)
43+
coef_γ(p, t) = sqrt(p[3])
44+
c_ops = [QobjEvo(a, coef_γ)]
45+
L = liouvillian(H_evo, c_ops)
46+
47+
function my_f_mesolve(p)
48+
sol = mesolve(
49+
L,
50+
ψ0,
51+
tlist,
52+
progress_bar = Val(false),
53+
params = p,
54+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
55+
)
56+
return real(expect(a' * a, sol.states[end]))
57+
end
58+
59+
# Parameters for benchmarks
60+
params_sesolve = [1.0, 1.0]
61+
params_mesolve = [1.0, 1.0, 1.0]
62+
63+
# Benchmark sesolve - Forward
64+
SUITE["Autodiff"]["sesolve"]["Forward"] = @benchmarkable ForwardDiff.gradient($my_f_sesolve_direct, $params_sesolve)
65+
66+
# Benchmark sesolve - Reverse (Zygote)
67+
SUITE["Autodiff"]["sesolve"]["Reverse (Zygote)"] = @benchmarkable Zygote.gradient($my_f_sesolve, $params_sesolve)
68+
69+
# Benchmark sesolve - Reverse (Enzyme)
70+
SUITE["Autodiff"]["sesolve"]["Reverse (Enzyme)"] = @benchmarkable Enzyme.autodiff(
71+
Enzyme.set_runtime_activity(Enzyme.Reverse),
72+
Const($my_f_sesolve),
73+
Active,
74+
Duplicated($params_sesolve, dparams_sesolve),
75+
) setup=(dparams_sesolve = Enzyme.make_zero($params_sesolve))
76+
77+
# Benchmark mesolve - Forward
78+
SUITE["Autodiff"]["mesolve"]["Forward"] = @benchmarkable ForwardDiff.gradient($my_f_mesolve_direct, $params_mesolve)
79+
80+
# Benchmark mesolve - Reverse (Zygote)
81+
SUITE["Autodiff"]["mesolve"]["Reverse (Zygote)"] = @benchmarkable Zygote.gradient($my_f_mesolve, $params_mesolve)
82+
83+
# Benchmark mesolve - Reverse (Enzyme)
84+
SUITE["Autodiff"]["mesolve"]["Reverse (Enzyme)"] = @benchmarkable Enzyme.autodiff(
85+
Enzyme.set_runtime_activity(Enzyme.Reverse),
86+
Const($my_f_mesolve),
87+
Active,
88+
Duplicated($params_mesolve, dparams_mesolve),
89+
) setup=(dparams_mesolve = Enzyme.make_zero($params_mesolve))
90+
91+
return nothing
92+
end

benchmarks/runbenchmarks.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ using LinearAlgebra
33
using SparseArrays
44
using QuantumToolbox
55
using SciMLBase: EnsembleSerial, EnsembleThreads
6+
using ForwardDiff
7+
using Zygote
8+
using Enzyme: Enzyme, Const, Active, Duplicated
9+
using SciMLSensitivity: BacksolveAdjoint, EnzymeVJP
610

711
BLAS.set_num_threads(1)
812

@@ -14,13 +18,15 @@ include("dynamical_shifted_fock.jl")
1418
include("eigenvalues.jl")
1519
include("steadystate.jl")
1620
include("timeevolution.jl")
21+
include("autodiff.jl")
1722

1823
benchmark_correlations_and_spectrum!(SUITE)
1924
benchmark_dfd!(SUITE)
2025
benchmark_dsf!(SUITE)
2126
benchmark_eigenvalues!(SUITE)
2227
benchmark_steadystate!(SUITE)
2328
benchmark_timeevolution!(SUITE)
29+
benchmark_autodiff!(SUITE)
2430

2531
BenchmarkTools.tune!(SUITE)
2632
results = BenchmarkTools.run(SUITE, verbose = true)

src/time_evolution/mesolve.jl

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,11 @@ function mesolve_map(
319319
to_dense(T, mat2vec(ket2dm(state).data))
320320
end
321321
end
322-
iter =
323-
params isa NullParameters ? collect(Iterators.product(ψ0_iter, [params])) :
324-
collect(Iterators.product(ψ0_iter, params...))
325-
ntraj = length(iter)
322+
if params isa NullParameters
323+
iter = collect(Iterators.product(ψ0_iter, [params])) |> vec # convert nx1 Matrix into Vector
324+
else
325+
iter = collect(Iterators.product(ψ0_iter, params...))
326+
end
326327

327328
# we disable the progress bar of the mesolveProblem because we use a global progress bar for all the trajectories
328329
prob = mesolveProblem(
@@ -336,35 +337,49 @@ function mesolve_map(
336337
kwargs...,
337338
)
338339

339-
# generate and solve ensemble problem
340-
_output_func = _ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) # setup global progress bar
340+
return mesolve_map(prob, iter, alg, ensemblealg; progress_bar = progress_bar)
341+
end
342+
mesolve_map(
343+
H::Union{AbstractQuantumObject{HOpType},Tuple},
344+
ψ0::QuantumObject{StateOpType},
345+
tlist::AbstractVector,
346+
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
347+
kwargs...,
348+
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} =
349+
mesolve_map(H, [ψ0], tlist, c_ops; kwargs...)
350+
351+
# this method is for advanced usage
352+
# User can define their own iterator structure, prob_func and output_func
353+
# - `prob_func`: Function to use for generating the ODEProblem.
354+
# - `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
355+
#
356+
# Return: An array of TimeEvolutionSol objects with the size same as the given iter.
357+
function mesolve_map(
358+
prob::TimeEvolutionProblem{<:ODEProblem},
359+
iter::AbstractArray,
360+
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
361+
ensemblealg::EnsembleAlgorithm = EnsembleThreads();
362+
prob_func::Union{Function,Nothing} = nothing,
363+
output_func::Union{Tuple,Nothing} = nothing,
364+
progress_bar::Union{Val,Bool} = Val(true),
365+
)
366+
# generate ensemble problem
367+
ntraj = length(iter)
368+
_prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func
369+
_output_func =
370+
isnothing(output_func) ?
371+
_ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) : output_func
341372
ens_prob = TimeEvolutionProblem(
342-
EnsembleProblem(
343-
prob.prob,
344-
prob_func = (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter),
345-
output_func = _output_func[1],
346-
safetycopy = false,
347-
),
373+
EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
348374
prob.times,
349375
prob.dimensions,
350376
(progr = _output_func[2], channel = _output_func[3], isoperket = prob.kwargs.isoperket),
351377
)
378+
352379
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
353380

354381
# handle solution and make it become an Array of TimeEvolutionSol
355382
sol_vec =
356383
[_gen_mesolve_solution(sol[:, i], prob.times, prob.dimensions, prob.kwargs.isoperket) for i in eachindex(sol)] # map is type unstable
357-
if params isa NullParameters # if no parameters specified, just return a Vector
358-
return sol_vec
359-
else
360-
return reshape(sol_vec, size(iter))
361-
end
384+
return reshape(sol_vec, size(iter))
362385
end
363-
mesolve_map(
364-
H::Union{AbstractQuantumObject{HOpType},Tuple},
365-
ψ0::QuantumObject{StateOpType},
366-
tlist::AbstractVector,
367-
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
368-
kwargs...,
369-
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} =
370-
mesolve_map(H, [ψ0], tlist, c_ops; kwargs...)

src/time_evolution/sesolve.jl

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,11 @@ function sesolve_map(
235235
)
236236
# mapping initial states and parameters
237237
ψ0_iter = map(get_data, ψ0)
238-
iter =
239-
params isa NullParameters ? collect(Iterators.product(ψ0_iter, [params])) :
240-
collect(Iterators.product(ψ0_iter, params...))
241-
ntraj = length(iter)
238+
if params isa NullParameters
239+
iter = collect(Iterators.product(ψ0_iter, [params])) |> vec # convert nx1 Matrix into Vector
240+
else
241+
iter = collect(Iterators.product(ψ0_iter, params...))
242+
end
242243

243244
# we disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
244245
prob = sesolveProblem(
@@ -251,28 +252,42 @@ function sesolve_map(
251252
kwargs...,
252253
)
253254

254-
# generate and solve ensemble problem
255-
_output_func = _ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) # setup global progress bar
255+
return sesolve_map(prob, iter, alg, ensemblealg; progress_bar = progress_bar)
256+
end
257+
sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{Ket}, tlist::AbstractVector; kwargs...) =
258+
sesolve_map(H, [ψ0], tlist; kwargs...)
259+
260+
# this method is for advanced usage
261+
# User can define their own iterator structure, prob_func and output_func
262+
# - `prob_func`: Function to use for generating the ODEProblem.
263+
# - `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
264+
#
265+
# Return: An array of TimeEvolutionSol objects with the size same as the given iter.
266+
function sesolve_map(
267+
prob::TimeEvolutionProblem{<:ODEProblem},
268+
iter::AbstractArray,
269+
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
270+
ensemblealg::EnsembleAlgorithm = EnsembleThreads();
271+
prob_func::Union{Function,Nothing} = nothing,
272+
output_func::Union{Tuple,Nothing} = nothing,
273+
progress_bar::Union{Val,Bool} = Val(true),
274+
)
275+
# generate ensemble problem
276+
ntraj = length(iter)
277+
_prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func
278+
_output_func =
279+
isnothing(output_func) ?
280+
_ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) : output_func
256281
ens_prob = TimeEvolutionProblem(
257-
EnsembleProblem(
258-
prob.prob,
259-
prob_func = (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter),
260-
output_func = _output_func[1],
261-
safetycopy = false,
262-
),
282+
EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
263283
prob.times,
264284
prob.dimensions,
265285
(progr = _output_func[2], channel = _output_func[3]),
266286
)
287+
267288
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
268289

269290
# handle solution and make it become an Array of TimeEvolutionSol
270291
sol_vec = [_gen_sesolve_solution(sol[:, i], prob.times, prob.dimensions) for i in eachindex(sol)] # map is type unstable
271-
if params isa NullParameters # if no parameters specified, just return a Vector
272-
return sol_vec
273-
else
274-
return reshape(sol_vec, size(iter))
275-
end
292+
return reshape(sol_vec, size(iter))
276293
end
277-
sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{Ket}, tlist::AbstractVector; kwargs...) =
278-
sesolve_map(H, [ψ0], tlist; kwargs...)

0 commit comments

Comments
 (0)