Skip to content

Commit 77ecf53

Browse files
Add autodiff benchmarks (#564)
1 parent 91fd289 commit 77ecf53

File tree

4 files changed

+103
-1
lines changed

4 files changed

+103
-1
lines changed

.github/workflows/Benchmarks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
- uses: actions/checkout@v5
3838
- uses: julia-actions/setup-julia@v2
3939
with:
40-
version: '1'
40+
version: '1.11'
4141
arch: x64
4242
- uses: julia-actions/cache@v2
4343
- name: Run benchmark

benchmarks/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
35
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"
46
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
7+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
8+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

benchmarks/autodiff.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
function benchmark_autodiff!(SUITE)
2+
# Use harmonic oscillator system for both sesolve and mesolve
3+
N = 20
4+
a = destroy(N)
5+
ψ0 = fock(N, 0)
6+
tlist = range(0, 40, 100)
7+
8+
# ---- SESOLVE ----
9+
# For direct Forward differentiation
10+
function my_f_sesolve_direct(p)
11+
H = p[1] * a' * a + p[2] * (a + a')
12+
sol = sesolve(H, ψ0, tlist, progress_bar = Val(false))
13+
return real(expect(a' * a, sol.states[end]))
14+
end
15+
16+
# For SciMLSensitivity.jl (reverse mode with Zygote and Enzyme)
17+
coef_Δ(p, t) = p[1]
18+
coef_F(p, t) = p[2]
19+
H_evo = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
20+
21+
function my_f_sesolve(p)
22+
sol = sesolve(
23+
H_evo,
24+
ψ0,
25+
tlist,
26+
progress_bar = Val(false),
27+
params = p,
28+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
29+
)
30+
return real(expect(a' * a, sol.states[end]))
31+
end
32+
33+
# ---- MESOLVE ----
34+
# For direct Forward differentiation
35+
function my_f_mesolve_direct(p)
36+
H = p[1] * a' * a + p[2] * (a + a')
37+
c_ops = [sqrt(p[3]) * a]
38+
sol = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false))
39+
return real(expect(a' * a, sol.states[end]))
40+
end
41+
42+
# For SciMLSensitivity.jl (reverse mode with Zygote and Enzyme)
43+
coef_γ(p, t) = sqrt(p[3])
44+
c_ops = [QobjEvo(a, coef_γ)]
45+
L = liouvillian(H_evo, c_ops)
46+
47+
function my_f_mesolve(p)
48+
sol = mesolve(
49+
L,
50+
ψ0,
51+
tlist,
52+
progress_bar = Val(false),
53+
params = p,
54+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
55+
)
56+
return real(expect(a' * a, sol.states[end]))
57+
end
58+
59+
# Parameters for benchmarks
60+
params_sesolve = [1.0, 1.0]
61+
params_mesolve = [1.0, 1.0, 1.0]
62+
63+
# Benchmark sesolve - Forward
64+
SUITE["Autodiff"]["sesolve"]["Forward"] = @benchmarkable ForwardDiff.gradient($my_f_sesolve_direct, $params_sesolve)
65+
66+
# Benchmark sesolve - Reverse (Zygote)
67+
SUITE["Autodiff"]["sesolve"]["Reverse (Zygote)"] = @benchmarkable Zygote.gradient($my_f_sesolve, $params_sesolve)
68+
69+
# Benchmark sesolve - Reverse (Enzyme)
70+
SUITE["Autodiff"]["sesolve"]["Reverse (Enzyme)"] = @benchmarkable Enzyme.autodiff(
71+
Enzyme.set_runtime_activity(Enzyme.Reverse),
72+
Const($my_f_sesolve),
73+
Active,
74+
Duplicated($params_sesolve, dparams_sesolve),
75+
) setup=(dparams_sesolve = Enzyme.make_zero($params_sesolve))
76+
77+
# Benchmark mesolve - Forward
78+
SUITE["Autodiff"]["mesolve"]["Forward"] = @benchmarkable ForwardDiff.gradient($my_f_mesolve_direct, $params_mesolve)
79+
80+
# Benchmark mesolve - Reverse (Zygote)
81+
SUITE["Autodiff"]["mesolve"]["Reverse (Zygote)"] = @benchmarkable Zygote.gradient($my_f_mesolve, $params_mesolve)
82+
83+
# Benchmark mesolve - Reverse (Enzyme)
84+
SUITE["Autodiff"]["mesolve"]["Reverse (Enzyme)"] = @benchmarkable Enzyme.autodiff(
85+
Enzyme.set_runtime_activity(Enzyme.Reverse),
86+
Const($my_f_mesolve),
87+
Active,
88+
Duplicated($params_mesolve, dparams_mesolve),
89+
) setup=(dparams_mesolve = Enzyme.make_zero($params_mesolve))
90+
91+
return nothing
92+
end

benchmarks/runbenchmarks.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ using LinearAlgebra
33
using SparseArrays
44
using QuantumToolbox
55
using SciMLBase: EnsembleSerial, EnsembleThreads
6+
using ForwardDiff
7+
using Zygote
8+
using Enzyme: Enzyme, Const, Active, Duplicated
9+
using SciMLSensitivity: BacksolveAdjoint, EnzymeVJP
610

711
BLAS.set_num_threads(1)
812

@@ -14,13 +18,15 @@ include("dynamical_shifted_fock.jl")
1418
include("eigenvalues.jl")
1519
include("steadystate.jl")
1620
include("timeevolution.jl")
21+
include("autodiff.jl")
1722

1823
benchmark_correlations_and_spectrum!(SUITE)
1924
benchmark_dfd!(SUITE)
2025
benchmark_dsf!(SUITE)
2126
benchmark_eigenvalues!(SUITE)
2227
benchmark_steadystate!(SUITE)
2328
benchmark_timeevolution!(SUITE)
29+
benchmark_autodiff!(SUITE)
2430

2531
BenchmarkTools.tune!(SUITE)
2632
results = BenchmarkTools.run(SUITE, verbose = true)

0 commit comments

Comments
 (0)