Skip to content

Commit 9ab6087

Browse files
Add sesolve tests
1 parent 090cd38 commit 9ab6087

File tree

6 files changed

+74
-10
lines changed

6 files changed

+74
-10
lines changed

.github/workflows/CI.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ jobs:
5555
- 'Core'
5656

5757
include:
58+
# for core tests (intermediate versions)
59+
# - version: '1.x'
60+
# node:
61+
# os: 'ubuntu-latest'
62+
# arch: 'x64'
63+
# group: 'Core'
64+
5865
# for extension tests
5966
- version: '1'
6067
node:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3737
QuantumToolboxCUDAExt = "CUDA"
3838
QuantumToolboxCairoMakieExt = "CairoMakie"
3939
QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"]
40-
QuantumToolboxChainRulesCore = "ChainRulesCore"
40+
QuantumToolboxChainRulesCoreExt = "ChainRulesCore"
4141

4242
[compat]
4343
ArrayInterface = "6, 7"

ext/QuantumToolboxChainRulesCore.jl renamed to ext/QuantumToolboxChainRulesCoreExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
module QuantumToolboxChainRulesCore
1+
module QuantumToolboxChainRulesCoreExt
22

33
using LinearAlgebra
44
import QuantumToolbox: QuantumObject
5-
using ChainRulesCore
5+
import ChainRulesCore: rrule, NoTangent, Tangent
66

7-
function ChainRulesCore.rrule(::Type{QuantumObject}, data, type, dimensions)
7+
function rrule(::Type{QuantumObject}, data, type, dimensions)
88
obj = QuantumObject(data, type, dimensions)
99
f_pullback(Δobj) = (NoTangent(), Δobj.data, NoTangent(), NoTangent())
1010
f_pullback(Δobj_data::AbstractArray) = (NoTangent(), Δobj_data, NoTangent(), NoTangent())

src/QuantumToolbox.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,28 @@ import StaticArraysCore: MVector
7373
# to achieve better performances for more massive parallelizations
7474
BLAS.set_num_threads(1)
7575

76+
# TODO: To remove when https://github.com/SciML/SciMLOperators.jl/pull/264 is merged
77+
SCALINGNUMBERTYPES = (:AbstractSciMLScalarOperator, :Number, :UniformScaling)
78+
# Special cases for constant scalars. These simplify the structure when applicable
79+
for T in SCALINGNUMBERTYPES[2:end]
80+
@eval function Base.:*::$T, L::ScaledOperator)
81+
isconstant(L.λ) && return ScaledOperator* L.λ, L.L)
82+
return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule
83+
end
84+
@eval function Base.:*(L::ScaledOperator, α::$T)
85+
isconstant(L.λ) && return ScaledOperator* L.λ, L.L)
86+
return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule
87+
end
88+
@eval function Base.:*::$T, L::MatrixOperator)
89+
isconstant(L) && return MatrixOperator* L.A)
90+
return ScaledOperator(α, L) # Going back to the generic case
91+
end
92+
@eval function Base.:*(L::MatrixOperator, α::$T)
93+
isconstant(L) && return MatrixOperator* L.A)
94+
return ScaledOperator(α, L) # Going back to the generic case
95+
end
96+
end
97+
7698
# Utility
7799
include("utilities.jl")
78100
include("versioninfo.jl")

src/qobj/quantum_object_evo.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution`
397397
)
398398

399399
op = :(op_func_list[$i][1])
400-
data_type = op_type.parameters[1]
401400
dims_expr = (dims_expr..., :($op.dimensions))
402401
func_methods_expr = (func_methods_expr..., :(methods(op_func_list[$i][2], [Any, Real]))) # [Any, Real] means each func must accept 2 arguments
403402
if i == 1
@@ -409,7 +408,6 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution`
409408
(isoper(op_type) || issuper(op_type)) ||
410409
throw(ArgumentError("The element must be a Operator or SuperOperator."))
411410

412-
data_type = op_type.parameters[1]
413411
dims_expr = (dims_expr..., :(op_func_list[$i].dimensions))
414412
if i == 1
415413
first_op = :(op_func_list[$i])

test/ext-test/autodiff/zygote.jl

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,45 @@
1-
@testset "Zygote.jl Autodiff" begin
1+
@testset "Zygote.jl Autodiff" verbose=true begin
2+
@testset "sesolve" begin
3+
coef_Ω(p, t) = p[1]
4+
5+
H = QobjEvo(sigmax(), coef_Ω)
6+
ψ0 = fock(2, 1)
7+
t_max = 10
8+
9+
function my_f_sesolve(p)
10+
tlist = range(0, t_max, 100)
11+
12+
sol = sesolve(
13+
H,
14+
ψ0,
15+
tlist,
16+
progress_bar = Val(false),
17+
params = p,
18+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
19+
)
20+
21+
return real(expect(projection(2, 0, 0), sol.states[end]))
22+
end
23+
24+
# Analytical solution
25+
my_f_analytic(Ω) = abs2(sin* t_max))
26+
my_f_analytic_deriv(Ω) = 2 * t_max * sin* t_max) * cos* t_max)
27+
28+
Ω = 1.0
29+
params = [Ω]
30+
31+
my_f_analytic(Ω)
32+
my_f_sesolve(params)
33+
34+
35+
grad_qt = Zygote.gradient(my_f_sesolve, params)[1] ./ 2
36+
grad_exact = [my_f_analytic_deriv(params[1])]
37+
38+
@test grad_qt grad_exact atol=1e-6
39+
end
40+
241
@testset "mesolve" begin
3-
N = 16
42+
N = 20
443
a = destroy(N)
544

645
coef_Δ(p, t) = p[1]
@@ -35,8 +74,6 @@
3574
F = 1.0
3675
γ = 1.0
3776
params = [Δ, F, γ]
38-
my_f_mesolve(params)
39-
n_ss(Δ, F, γ)
4077

4178
# The factor 2 is due to a bug
4279
grad_qt = Zygote.gradient(my_f_mesolve, params)[1] ./ 2

0 commit comments

Comments
 (0)