Skip to content

Commit acc0b2c

Browse files
Add Enzyme benchmarks
1 parent ed5783d commit acc0b2c

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

benchmarks/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
34
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
45
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"
56
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

benchmarks/autodiff.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,37 @@ function benchmark_autodiff!(SUITE)
6060
params_sesolve = [1.0, 1.0]
6161
params_mesolve = [1.0, 1.0, 1.0]
6262

63+
# Pre-allocate gradient arrays for Enzyme
64+
dparams_sesolve = Enzyme.make_zero(params_sesolve)
65+
dparams_mesolve = Enzyme.make_zero(params_mesolve)
66+
6367
# Benchmark sesolve - Forward
6468
SUITE["Autodiff"]["sesolve"]["Forward"] = @benchmarkable ForwardDiff.gradient($my_f_sesolve_direct, $params_sesolve)
6569

6670
# Benchmark sesolve - Reverse (Zygote)
67-
SUITE["Autodiff"]["sesolve"]["Reverse"] = @benchmarkable Zygote.gradient($my_f_sesolve, $params_sesolve)
71+
SUITE["Autodiff"]["sesolve"]["Reverse (Zygote)"] = @benchmarkable Zygote.gradient($my_f_sesolve, $params_sesolve)
72+
73+
# Benchmark sesolve - Reverse (Enzyme)
74+
SUITE["Autodiff"]["sesolve"]["Reverse (Enzyme)"] = @benchmarkable Enzyme.autodiff(
75+
Enzyme.set_runtime_activity(Enzyme.Reverse),
76+
Const($my_f_sesolve),
77+
Active,
78+
Duplicated($params_sesolve, $dparams_sesolve),
79+
)
6880

6981
# Benchmark mesolve - Forward
7082
SUITE["Autodiff"]["mesolve"]["Forward"] = @benchmarkable ForwardDiff.gradient($my_f_mesolve_direct, $params_mesolve)
7183

7284
# Benchmark mesolve - Reverse (Zygote)
73-
SUITE["Autodiff"]["mesolve"]["Reverse"] = @benchmarkable Zygote.gradient($my_f_mesolve, $params_mesolve)
85+
SUITE["Autodiff"]["mesolve"]["Reverse (Zygote)"] = @benchmarkable Zygote.gradient($my_f_mesolve, $params_mesolve)
86+
87+
# Benchmark mesolve - Reverse (Enzyme)
88+
SUITE["Autodiff"]["mesolve"]["Reverse (Enzyme)"] = @benchmarkable Enzyme.autodiff(
89+
Enzyme.set_runtime_activity(Enzyme.Reverse),
90+
Const($my_f_mesolve),
91+
Active,
92+
Duplicated($params_mesolve, $dparams_mesolve),
93+
)
7494

7595
return nothing
7696
end

benchmarks/runbenchmarks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using QuantumToolbox
55
using SciMLBase: EnsembleSerial, EnsembleThreads
66
using ForwardDiff
77
using Zygote
8+
using Enzyme: Enzyme, Const, Active, Duplicated
89
using SciMLSensitivity: BacksolveAdjoint, EnzymeVJP
910

1011
BLAS.set_num_threads(1)

0 commit comments

Comments
 (0)