Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ jobs:
os: 'ubuntu-latest'
arch: 'x64'
group: 'CairoMakie_Ext'
- version: '1'
node:
os: 'ubuntu-latest'
arch: 'x64'
group: 'AutoDiff_Ext'

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make CUDA conversion more general using Adapt.jl. ([#436], [#437])
- Make the generation of `fock` states non-mutating to support Zygote.jl. ([#438])
- Remove Reexport.jl from the dependencies. ([#443])
- Add support for automatic differentiation for `sesolve` and `mesolve`. ([#440])

## [v0.29.1]
Release date: 2025-03-07
Expand Down Expand Up @@ -194,4 +195,5 @@ Release date: 2024-11-13
[#436]: https://github.com/qutip/QuantumToolbox.jl/issues/436
[#437]: https://github.com/qutip/QuantumToolbox.jl/issues/437
[#438]: https://github.com/qutip/QuantumToolbox.jl/issues/438
[#440]: https://github.com/qutip/QuantumToolbox.jl/issues/440
[#443]: https://github.com/qutip/QuantumToolbox.jl/issues/443
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
QuantumToolboxCUDAExt = "CUDA"
QuantumToolboxCairoMakieExt = "CairoMakie"
QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"]
QuantumToolboxChainRulesCoreExt = "ChainRulesCore"

[compat]
ArrayInterface = "6, 7"
CUDA = "5"
CairoMakie = "0.12, 0.13"
ChainRulesCore = "1"
DiffEqBase = "6"
DiffEqCallbacks = "4.2.1 - 4"
DiffEqNoiseProcess = "5"
Expand Down
14 changes: 14 additions & 0 deletions ext/QuantumToolboxChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module QuantumToolboxChainRulesCoreExt

using LinearAlgebra
import QuantumToolbox: QuantumObject
import ChainRulesCore: rrule, NoTangent, Tangent

function rrule(::Type{QuantumObject}, data, type, dimensions)
obj = QuantumObject(data, type, dimensions)
f_pullback(Δobj) = (NoTangent(), Δobj.data, NoTangent(), NoTangent())
f_pullback(Δobj_data::AbstractArray) = (NoTangent(), Δobj_data, NoTangent(), NoTangent())
return obj, f_pullback
end

end
22 changes: 22 additions & 0 deletions src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,28 @@ export permute
# SciMLOperators
export cache_operator, iscached, isconstant

# TODO: To remove when https://github.com/SciML/SciMLOperators.jl/pull/264 is merged
SCALINGNUMBERTYPES = (:AbstractSciMLScalarOperator, :Number, :UniformScaling)
# Special cases for constant scalars. These simplify the structure when applicable
for T in SCALINGNUMBERTYPES[2:end]
@eval function Base.:*(α::$T, L::ScaledOperator)
isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L)
return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule
end
@eval function Base.:*(L::ScaledOperator, α::$T)
isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L)
return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule
end
@eval function Base.:*(α::$T, L::MatrixOperator)
isconstant(L) && return MatrixOperator(α * L.A)
return ScaledOperator(α, L) # Going back to the generic case
end
@eval function Base.:*(L::MatrixOperator, α::$T)
isconstant(L) && return MatrixOperator(α * L.A)
return ScaledOperator(α, L) # Going back to the generic case
end
end

# Utility
include("utilities.jl")
include("versioninfo.jl")
Expand Down
2 changes: 0 additions & 2 deletions src/qobj/quantum_object_evo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution`
)

op = :(op_func_list[$i][1])
data_type = op_type.parameters[1]
dims_expr = (dims_expr..., :($op.dimensions))
func_methods_expr = (func_methods_expr..., :(methods(op_func_list[$i][2], [Any, Real]))) # [Any, Real] means each func must accept 2 arguments
if i == 1
Expand All @@ -409,7 +408,6 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution`
(isoper(op_type) || issuper(op_type)) ||
throw(ArgumentError("The element must be a Operator or SuperOperator."))

data_type = op_type.parameters[1]
dims_expr = (dims_expr..., :(op_func_list[$i].dimensions))
if i == 1
first_op = :(op_func_list[$i])
Expand Down
8 changes: 7 additions & 1 deletion src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ function mesolveProblem(
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncMESolve)

tspan = (tlist[1], tlist[end])
prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs3...)

# TODO: Remove this when https://github.com/SciML/SciMLSensitivity.jl/issues/1181 is fixed
if haskey(kwargs3, :sensealg)
prob = ODEProblem{getVal(inplace)}(L, ρ0, tspan, params; kwargs3...)
else
prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs3...)
end

return TimeEvolutionProblem(prob, tlist, L_evo.dimensions, (isoperket = Val(isoperket(ψ0)),))
end
Expand Down
8 changes: 7 additions & 1 deletion src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ function sesolveProblem(
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSESolve)

tspan = (tlist[1], tlist[end])
prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs3...)

# TODO: Remove this when https://github.com/SciML/SciMLSensitivity.jl/issues/1181 is fixed
if haskey(kwargs3, :sensealg)
prob = ODEProblem{getVal(inplace)}(U, ψ0, tspan, params; kwargs3...)
else
prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs3...)
end

return TimeEvolutionProblem(prob, tlist, H_evo.dimensions)
end
Expand Down
5 changes: 5 additions & 0 deletions test/ext-test/cpu/autodiff/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
84 changes: 84 additions & 0 deletions test/ext-test/cpu/autodiff/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
@testset "Zygote.jl Autodiff" verbose=true begin
@testset "sesolve" begin
coef_Ω(p, t) = p[1]

H = QobjEvo(sigmax(), coef_Ω)
ψ0 = fock(2, 1)
t_max = 10

function my_f_sesolve(p)
tlist = range(0, t_max, 100)

sol = sesolve(
H,
ψ0,
tlist,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(projection(2, 0, 0), sol.states[end]))
end

# Analytical solution
my_f_analytic(Ω) = abs2(sin(Ω * t_max))
my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max)

Ω = 1.0
params = [Ω]

my_f_analytic(Ω)
my_f_sesolve(params)

grad_qt = Zygote.gradient(my_f_sesolve, params)[1]
grad_exact = [my_f_analytic_deriv(params[1])]

@test grad_qt ≈ grad_exact atol=1e-6
end

@testset "mesolve" begin
N = 20
a = destroy(N)

coef_Δ(p, t) = p[1]
coef_F(p, t) = p[2]
coef_γ(p, t) = sqrt(p[3])

H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
c_ops = [QobjEvo(a, coef_γ)]
L = liouvillian(H, c_ops)

ψ0 = fock(N, 0)

function my_f_mesolve(p)
tlist = range(0, 40, 100)

sol = mesolve(
L,
ψ0,
tlist,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(a' * a, sol.states[end]))
end

# Analytical solution
n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2))

Δ = 1.0
F = 1.0
γ = 1.0
params = [Δ, F, γ]

# The factor 2 is due to a bug
grad_qt = Zygote.gradient(my_f_mesolve, params)[1]

grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1]

@test grad_qt ≈ grad_exact atol=1e-6
end
end
17 changes: 15 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,20 @@ if (GROUP == "All") || (GROUP == "Code-Quality")
include(joinpath(testdir, "core-test", "code-quality", "code_quality.jl"))
end

if (GROUP == "CairoMakie_Ext")# || (GROUP == "All")
if (GROUP == "AutoDiff_Ext")
Pkg.activate("ext-test/cpu/autodiff")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()

using QuantumToolbox
using Zygote
using Enzyme
using SciMLSensitivity

include(joinpath(testdir, "ext-test", "cpu", "autodiff", "zygote.jl"))
end

if (GROUP == "CairoMakie_Ext")
Pkg.activate("ext-test/cpu/cairomakie")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
Expand All @@ -67,7 +80,7 @@ if (GROUP == "CairoMakie_Ext")# || (GROUP == "All")
include(joinpath(testdir, "ext-test", "cpu", "cairomakie", "cairomakie_ext.jl"))
end

if (GROUP == "CUDA_Ext")# || (GROUP == "All")
if (GROUP == "CUDA_Ext")
Pkg.activate("ext-test/gpu")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
Expand Down