Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ jobs:
os: 'ubuntu-latest'
arch: 'x64'
group: 'CairoMakie_Ext'
- version: '1'
node:
os: 'ubuntu-latest'
arch: 'x64'
group: 'AutoDiff_Ext'

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make CUDA conversion more general using Adapt.jl. ([#436], [#437])
- Make the generation of `fock` states non-mutating to support Zygote.jl. ([#438])
- Remove Reexport.jl from the dependencies. ([#443])
- Add support for automatic differentiation for `sesolve` and `mesolve`. ([#440])

## [v0.29.1]
Release date: 2025-03-07
Expand Down Expand Up @@ -194,4 +195,5 @@ Release date: 2024-11-13
[#436]: https://github.com/qutip/QuantumToolbox.jl/issues/436
[#437]: https://github.com/qutip/QuantumToolbox.jl/issues/437
[#438]: https://github.com/qutip/QuantumToolbox.jl/issues/438
[#440]: https://github.com/qutip/QuantumToolbox.jl/issues/440
[#443]: https://github.com/qutip/QuantumToolbox.jl/issues/443
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
QuantumToolboxCUDAExt = "CUDA"
QuantumToolboxCairoMakieExt = "CairoMakie"
QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"]
QuantumToolboxChainRulesCoreExt = "ChainRulesCore"

[compat]
ArrayInterface = "6, 7"
CUDA = "5"
CairoMakie = "0.12, 0.13"
ChainRulesCore = "1"
DiffEqBase = "6"
DiffEqCallbacks = "4.2.1 - 4"
DiffEqNoiseProcess = "5"
Expand Down
14 changes: 14 additions & 0 deletions ext/QuantumToolboxChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module QuantumToolboxChainRulesCoreExt

using LinearAlgebra
import QuantumToolbox: QuantumObject
import ChainRulesCore: rrule, NoTangent, Tangent

function rrule(::Type{QuantumObject}, data, type, dimensions)
obj = QuantumObject(data, type, dimensions)
f_pullback(Δobj) = (NoTangent(), Δobj.data, NoTangent(), NoTangent())
f_pullback(Δobj_data::AbstractArray) = (NoTangent(), Δobj_data, NoTangent(), NoTangent())
return obj, f_pullback
end

end
31 changes: 23 additions & 8 deletions src/qobj/quantum_object_evo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@
if α isa Nothing
return QuantumObjectEvolution(op.data, type, op.dimensions)
end
return QuantumObjectEvolution(α * op.data, type, op.dimensions)
return QuantumObjectEvolution(_promote_to_scimloperator(α, op.data), type, op.dimensions)
end

#=
Expand Down Expand Up @@ -397,7 +397,6 @@
)

op = :(op_func_list[$i][1])
data_type = op_type.parameters[1]
dims_expr = (dims_expr..., :($op.dimensions))
func_methods_expr = (func_methods_expr..., :(methods(op_func_list[$i][2], [Any, Real]))) # [Any, Real] means each func must accept 2 arguments
if i == 1
Expand All @@ -409,7 +408,6 @@
(isoper(op_type) || issuper(op_type)) ||
throw(ArgumentError("The element must be a Operator or SuperOperator."))

data_type = op_type.parameters[1]
dims_expr = (dims_expr..., :(op_func_list[$i].dimensions))
if i == 1
first_op = :(op_func_list[$i])
Expand Down Expand Up @@ -445,16 +443,33 @@
T = eltype(op_func[1])
update_func = (a, u, p, t) -> op_func[2](p, t)
if α isa Nothing
return ScalarOperator(zero(T), update_func) * MatrixOperator(op_func[1].data)
return ScalarOperator(zero(T), update_func) * _promote_to_scimloperator(op_func[1].data)
end
return ScalarOperator(zero(T), update_func) * MatrixOperator(α * op_func[1].data)
return ScalarOperator(zero(T), update_func) * _promote_to_scimloperator(α, op_func[1].data)
end

function _make_SciMLOperator(op::QuantumObject, α)
function _make_SciMLOperator(op::AbstractQuantumObject, α)
if α isa Nothing
return MatrixOperator(op.data)
return _promote_to_scimloperator(op.data)
end
return MatrixOperator(α * op.data)
return _promote_to_scimloperator(α, op.data)
end

_promote_to_scimloperator(data::AbstractMatrix) = MatrixOperator(data)
_promote_to_scimloperator(data::AbstractSciMLOperator) = data

Check warning on line 459 in src/qobj/quantum_object_evo.jl

View check run for this annotation

Codecov / codecov/patch

src/qobj/quantum_object_evo.jl#L459

Added line #L459 was not covered by tests
# TODO: The following special cases can be simplified after
# https://github.com/SciML/SciMLOperators.jl/pull/264 is merged
_promote_to_scimloperator(α::Number, data::AbstractMatrix) = MatrixOperator(α * data)
function _promote_to_scimloperator(α::Number, data::MatrixOperator)
isconstant(data) && return MatrixOperator(α * data.A)
return ScaledOperator(α, data) # Going back to the generic case

Check warning on line 465 in src/qobj/quantum_object_evo.jl

View check run for this annotation

Codecov / codecov/patch

src/qobj/quantum_object_evo.jl#L465

Added line #L465 was not covered by tests
end
function _promote_to_scimloperator(α::Number, data::ScaledOperator)
isconstant(data.λ) && return ScaledOperator(α * data.λ, data.L)
return ScaledOperator(data.λ, _promote_to_scimloperator(α, data.L)) # Try to propagate the rule
end
function _promote_to_scimloperator(α::Number, data::AbstractSciMLOperator)
return α * data # Going back to the generic case
end

@doc raw"""
Expand Down
8 changes: 7 additions & 1 deletion src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ function mesolveProblem(
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncMESolve)

tspan = (tlist[1], tlist[end])
prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs3...)

# TODO: Remove this when https://github.com/SciML/SciMLSensitivity.jl/issues/1181 is fixed
if haskey(kwargs3, :sensealg)
prob = ODEProblem{getVal(inplace)}(L, ρ0, tspan, params; kwargs3...)
else
prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs3...)
end

return TimeEvolutionProblem(prob, tlist, L_evo.dimensions, (isoperket = Val(isoperket(ψ0)),))
end
Expand Down
8 changes: 7 additions & 1 deletion src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ function sesolveProblem(
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSESolve)

tspan = (tlist[1], tlist[end])
prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs3...)

# TODO: Remove this when https://github.com/SciML/SciMLSensitivity.jl/issues/1181 is fixed
if haskey(kwargs3, :sensealg)
prob = ODEProblem{getVal(inplace)}(U, ψ0, tspan, params; kwargs3...)
else
prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs3...)
end

return TimeEvolutionProblem(prob, tlist, H_evo.dimensions)
end
Expand Down
2 changes: 1 addition & 1 deletion test/core-test/quantum_objects_evo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@test_throws DimensionMismatch QobjEvo(a, type = SuperOperator)

ψ = fock(10, 3)
@test_throws TypeError QobjEvo(ψ)
@test_throws MethodError QobjEvo(ψ)
end

# unsupported type of dims
Expand Down
5 changes: 5 additions & 0 deletions test/ext-test/cpu/autodiff/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
84 changes: 84 additions & 0 deletions test/ext-test/cpu/autodiff/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
@testset "Zygote.jl Autodiff" verbose=true begin
@testset "sesolve" begin
coef_Ω(p, t) = p[1]

H = QobjEvo(sigmax(), coef_Ω)
ψ0 = fock(2, 1)
t_max = 10

function my_f_sesolve(p)
tlist = range(0, t_max, 100)

sol = sesolve(
H,
ψ0,
tlist,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(projection(2, 0, 0), sol.states[end]))
end

# Analytical solution
my_f_analytic(Ω) = abs2(sin(Ω * t_max))
my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max)

Ω = 1.0
params = [Ω]

my_f_analytic(Ω)
my_f_sesolve(params)

grad_qt = Zygote.gradient(my_f_sesolve, params)[1]
grad_exact = [my_f_analytic_deriv(params[1])]

@test grad_qt ≈ grad_exact atol=1e-6
end

@testset "mesolve" begin
N = 20
a = destroy(N)

coef_Δ(p, t) = p[1]
coef_F(p, t) = p[2]
coef_γ(p, t) = sqrt(p[3])

H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
c_ops = [QobjEvo(a, coef_γ)]
L = liouvillian(H, c_ops)

ψ0 = fock(N, 0)

function my_f_mesolve(p)
tlist = range(0, 40, 100)

sol = mesolve(
L,
ψ0,
tlist,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(a' * a, sol.states[end]))
end

# Analytical solution
n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2))

Δ = 1.0
F = 1.0
γ = 1.0
params = [Δ, F, γ]

# The factor 2 is due to a bug
grad_qt = Zygote.gradient(my_f_mesolve, params)[1]

grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1]

@test grad_qt ≈ grad_exact atol=1e-6
end
end
17 changes: 15 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,20 @@ if (GROUP == "All") || (GROUP == "Code-Quality")
include(joinpath(testdir, "core-test", "code-quality", "code_quality.jl"))
end

if (GROUP == "CairoMakie_Ext")# || (GROUP == "All")
if (GROUP == "AutoDiff_Ext")
Pkg.activate("ext-test/cpu/autodiff")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()

using QuantumToolbox
using Zygote
using Enzyme
using SciMLSensitivity

include(joinpath(testdir, "ext-test", "cpu", "autodiff", "zygote.jl"))
end

if (GROUP == "CairoMakie_Ext")
Pkg.activate("ext-test/cpu/cairomakie")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
Expand All @@ -67,7 +80,7 @@ if (GROUP == "CairoMakie_Ext")# || (GROUP == "All")
include(joinpath(testdir, "ext-test", "cpu", "cairomakie", "cairomakie_ext.jl"))
end

if (GROUP == "CUDA_Ext")# || (GROUP == "All")
if (GROUP == "CUDA_Ext")
Pkg.activate("ext-test/gpu")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
Expand Down
Loading