Skip to content

Commit 090cd38

Browse files
Working mesolve
1 parent 6898156 commit 090cd38

File tree

6 files changed

+90
-9
lines changed

6 files changed

+90
-9
lines changed

.github/workflows/CI.yml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,17 @@ jobs:
5555
- 'Core'
5656

5757
include:
58-
# for core tests (intermediate versions)
59-
# - version: '1.x'
60-
# node:
61-
# os: 'ubuntu-latest'
62-
# arch: 'x64'
63-
# group: 'Core'
64-
6558
# for extension tests
6659
- version: '1'
6760
node:
6861
os: 'ubuntu-latest'
6962
arch: 'x64'
7063
group: 'CairoMakie_Ext'
64+
- version: '1'
65+
node:
66+
os: 'ubuntu-latest'
67+
arch: 'x64'
68+
group: 'AutoDiff_Ext'
7169

7270
steps:
7371
- uses: actions/checkout@v4

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,19 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3131
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
3232
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
3333
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
34+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3435

3536
[extensions]
3637
QuantumToolboxCUDAExt = "CUDA"
3738
QuantumToolboxCairoMakieExt = "CairoMakie"
3839
QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"]
40+
QuantumToolboxChainRulesCore = "ChainRulesCore"
3941

4042
[compat]
4143
ArrayInterface = "6, 7"
4244
CUDA = "5"
4345
CairoMakie = "0.12, 0.13"
46+
ChainRulesCore = "1"
4447
DiffEqBase = "6"
4548
DiffEqCallbacks = "4.2.1 - 4"
4649
DiffEqNoiseProcess = "5"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module QuantumToolboxChainRulesCore
2+
3+
using LinearAlgebra
4+
import QuantumToolbox: QuantumObject
5+
using ChainRulesCore
6+
7+
function ChainRulesCore.rrule(::Type{QuantumObject}, data, type, dimensions)
8+
obj = QuantumObject(data, type, dimensions)
9+
f_pullback(Δobj) = (NoTangent(), Δobj.data, NoTangent(), NoTangent())
10+
f_pullback(Δobj_data::AbstractArray) = (NoTangent(), Δobj_data, NoTangent(), NoTangent())
11+
return obj, f_pullback
12+
end
13+
14+
end
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[deps]
2+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3+
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"
4+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
5+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/ext-test/autodiff/zygote.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
@testset "Zygote.jl Autodiff" begin
2+
@testset "mesolve" begin
3+
N = 16
4+
a = destroy(N)
5+
6+
coef_Δ(p, t) = p[1]
7+
coef_F(p, t) = p[2]
8+
coef_γ(p, t) = sqrt(p[3])
9+
10+
H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
11+
c_ops = [QobjEvo(a, coef_γ)]
12+
L = liouvillian(H, c_ops)
13+
14+
ψ0 = fock(N, 0)
15+
16+
function my_f_mesolve(p)
17+
tlist = range(0, 40, 100)
18+
19+
sol = mesolve(
20+
L,
21+
ψ0,
22+
tlist,
23+
progress_bar = Val(false),
24+
params = p,
25+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
26+
)
27+
28+
return real(expect(a' * a, sol.states[end]))
29+
end
30+
31+
# Analytical solution
32+
n_ss(Δ, F, γ) = abs2(F /+ 1im * γ / 2))
33+
34+
Δ = 1.0
35+
F = 1.0
36+
γ = 1.0
37+
params = [Δ, F, γ]
38+
my_f_mesolve(params)
39+
n_ss(Δ, F, γ)
40+
41+
# The factor 2 is due to a bug
42+
grad_qt = Zygote.gradient(my_f_mesolve, params)[1] ./ 2
43+
44+
grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1]
45+
46+
@test grad_qt grad_exact atol=1e-6
47+
end
48+
end

test/runtests.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,20 @@ if (GROUP == "All") || (GROUP == "Code-Quality")
5050
include(joinpath(testdir, "core-test", "code-quality", "code_quality.jl"))
5151
end
5252

53-
if (GROUP == "CairoMakie_Ext")# || (GROUP == "All")
53+
if (GROUP == "AutoDiff_Ext")
54+
Pkg.activate("ext-test/cpu/autodiff")
55+
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
56+
Pkg.instantiate()
57+
58+
using QuantumToolbox
59+
using Zygote
60+
using Enzyme
61+
using SciMLSensitivity
62+
63+
include(joinpath(testdir, "ext-test", "autodiff", "zygote.jl"))
64+
end
65+
66+
if (GROUP == "CairoMakie_Ext")
5467
Pkg.activate("ext-test/cpu/cairomakie")
5568
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
5669
Pkg.instantiate()
@@ -62,7 +75,7 @@ if (GROUP == "CairoMakie_Ext")# || (GROUP == "All")
6275
include(joinpath(testdir, "ext-test", "cpu", "cairomakie", "cairomakie_ext.jl"))
6376
end
6477

65-
if (GROUP == "CUDA_Ext")# || (GROUP == "All")
78+
if (GROUP == "CUDA_Ext")
6679
Pkg.activate("ext-test/gpu")
6780
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
6881
Pkg.instantiate()

0 commit comments

Comments
 (0)