Skip to content

Commit e2e398a

Browse files
Add ForwardDiff.jl tests
1 parent a52ccf6 commit e2e398a

File tree

4 files changed

+119
-86
lines changed

4 files changed

+119
-86
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
34
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"
45
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
5-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
@testset "Autodiff" verbose=true begin
2+
@testset "sesolve" verbose=true begin
3+
ψ0 = fock(2, 1)
4+
t_max = 10
5+
tlist = range(0, t_max, 100)
6+
7+
# For direct Forward differentiation
8+
function my_f_sesolve_direct(p)
9+
H = p[1] * sigmax()
10+
sol = sesolve(H, ψ0, tlist, progress_bar = Val(false))
11+
12+
return real(expect(projection(2, 0, 0), sol.states[end]))
13+
end
14+
15+
# For SciMLSensitivity.jl
16+
coef_Ω(p, t) = p[1]
17+
H_evo = QobjEvo(sigmax(), coef_Ω)
18+
19+
function my_f_sesolve(p)
20+
sol = sesolve(
21+
H_evo,
22+
ψ0,
23+
tlist,
24+
progress_bar = Val(false),
25+
params = p,
26+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
27+
)
28+
29+
return real(expect(projection(2, 0, 0), sol.states[end]))
30+
end
31+
32+
# Analytical solution
33+
my_f_analytic(Ω) = abs2(sin* t_max))
34+
my_f_analytic_deriv(Ω) = 2 * t_max * sin* t_max) * cos* t_max)
35+
36+
Ω = 1.0
37+
params = [Ω]
38+
39+
my_f_sesolve_direct(params)
40+
my_f_sesolve(params)
41+
42+
grad_exact = [my_f_analytic_deriv(params[1])]
43+
44+
@testset "ForwardDiff.jl" begin
45+
grad_qt = ForwardDiff.gradient(my_f_sesolve_direct, params)
46+
47+
@test grad_qt grad_exact atol=1e-6
48+
end
49+
50+
@testset "Zygote.jl" begin
51+
grad_qt = Zygote.gradient(my_f_sesolve, params)[1]
52+
53+
@test grad_qt grad_exact atol=1e-6
54+
end
55+
end
56+
57+
@testset "mesolve" verbose=true begin
58+
N = 20
59+
a = destroy(N)
60+
ψ0 = fock(N, 0)
61+
tlist = range(0, 40, 100)
62+
63+
# For direct Forward differentiation
64+
function my_f_mesolve_direct(p)
65+
H = p[1] * a' * a + p[2] * (a + a')
66+
c_ops = [sqrt(p[3]) * a]
67+
sol = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false))
68+
return real(expect(a' * a, sol.states[end]))
69+
end
70+
71+
# For SciMLSensitivity.jl
72+
coef_Δ(p, t) = p[1]
73+
coef_F(p, t) = p[2]
74+
coef_γ(p, t) = sqrt(p[3])
75+
H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
76+
c_ops = [QobjEvo(a, coef_γ)]
77+
L = liouvillian(H, c_ops)
78+
79+
function my_f_mesolve(p)
80+
sol = mesolve(
81+
L,
82+
ψ0,
83+
tlist,
84+
progress_bar = Val(false),
85+
params = p,
86+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
87+
)
88+
89+
return real(expect(a' * a, sol.states[end]))
90+
end
91+
92+
# Analytical solution
93+
n_ss(Δ, F, γ) = abs2(F /+ 1im * γ / 2))
94+
95+
Δ = 1.0
96+
F = 1.0
97+
γ = 1.0
98+
params = [Δ, F, γ]
99+
100+
my_f_mesolve_direct(params)
101+
my_f_mesolve(params)
102+
103+
grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1]
104+
105+
@testset "ForwardDiff.jl" begin
106+
grad_qt = ForwardDiff.gradient(my_f_mesolve_direct, params)
107+
@test grad_qt grad_exact atol=1e-6
108+
end
109+
110+
@testset "Zygote.jl" begin
111+
grad_qt = Zygote.gradient(my_f_mesolve, params)[1]
112+
@test grad_qt grad_exact atol=1e-6
113+
end
114+
end
115+
end

test/ext-test/cpu/autodiff/zygote.jl

Lines changed: 0 additions & 84 deletions
This file was deleted.

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,14 @@ if (GROUP == "AutoDiff_Ext")
4242
Pkg.instantiate()
4343

4444
using QuantumToolbox
45+
using ForwardDiff
4546
using Zygote
4647
using Enzyme
4748
using SciMLSensitivity
4849

4950
QuantumToolbox.about()
5051

51-
include(joinpath(testdir, "ext-test", "cpu", "autodiff", "zygote.jl"))
52+
include(joinpath(testdir, "ext-test", "cpu", "autodiff", "autodiff.jl"))
5253
end
5354

5455
if (GROUP == "Makie_Ext")

0 commit comments

Comments
 (0)