Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main)

- Improve Bloch sphere rendering for animation. ([#520])
- Add support to `Enzyme.jl` for `sesolve` and `mesolve`. ([#531])

## [v0.34.0]
Release date: 2025-07-29
Expand Down Expand Up @@ -295,3 +296,4 @@ Release date: 2024-11-13
[#515]: https://github.com/qutip/QuantumToolbox.jl/issues/515
[#517]: https://github.com/qutip/QuantumToolbox.jl/issues/517
[#520]: https://github.com/qutip/QuantumToolbox.jl/issues/520
[#531]: https://github.com/qutip/QuantumToolbox.jl/issues/531
18 changes: 14 additions & 4 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ function mesolve(
kwargs...,
)

# Move sensealg argument to solve for Enzyme.jl support.
# TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed.
sensealg = get(kwargs, :sensealg, nothing)
kwargs_filtered = isnothing(sensealg) ? kwargs : Base.structdiff((; kwargs...), (sensealg = sensealg,))

prob = mesolveProblem(
H,
ψ0,
Expand All @@ -188,14 +193,19 @@ function mesolve(
params = params,
progress_bar = progress_bar,
inplace = inplace,
kwargs...,
kwargs_filtered...,
)

return mesolve(prob, alg)
# TODO: Remove sensealg when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed
if isnothing(sensealg)
return mesolve(prob, alg)
else
return mesolve(prob, alg; sensealg = sensealg)
end
end

function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
sol = solve(prob.prob, alg)
function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(); kwargs...)
sol = solve(prob.prob, alg; kwargs...)

# No type instabilities since `isoperket` is a Val, and so it is known at compile time
if getVal(prob.kwargs.isoperket)
Expand Down
19 changes: 15 additions & 4 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ function sesolve(
inplace::Union{Val,Bool} = Val(true),
kwargs...,
)

# Move sensealg argument to solve for Enzyme.jl support.
# TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed.
sensealg = get(kwargs, :sensealg, nothing)
kwargs_filtered = isnothing(sensealg) ? kwargs : Base.structdiff((; kwargs...), (sensealg = sensealg,))

prob = sesolveProblem(
H,
ψ0,
Expand All @@ -143,14 +149,19 @@ function sesolve(
params = params,
progress_bar = progress_bar,
inplace = inplace,
kwargs...,
kwargs_filtered...,
)

return sesolve(prob, alg)
# TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed.
if isnothing(sensealg)
return sesolve(prob, alg)
else
return sesolve(prob, alg; sensealg = sensealg)
end
end

function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
sol = solve(prob.prob, alg)
function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(); kwargs...)
sol = solve(prob.prob, alg; kwargs...)

ψt = map(ϕ -> QuantumObject(ϕ, type = Ket(), dims = prob.dimensions), sol.u)

Expand Down
158 changes: 92 additions & 66 deletions test/ext-test/cpu/autodiff/autodiff.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,77 @@
@testset "Autodiff" verbose=true begin
@testset "sesolve" verbose=true begin
ψ0 = fock(2, 1)
t_max = 10
tlist = range(0, t_max, 100)
# ---- SESOLVE ----
const ψ0_sesolve = fock(2, 1)
t_max = 10
const tlist_sesolve = range(0, t_max, 100)

# For direct Forward differentiation
function my_f_sesolve_direct(p)
H = p[1] * sigmax()
sol = sesolve(H, ψ0, tlist, progress_bar = Val(false))
# For direct Forward differentiation
function my_f_sesolve_direct(p)
H = p[1] * sigmax()
sol = sesolve(H, ψ0_sesolve, tlist_sesolve, progress_bar = Val(false))

return real(expect(projection(2, 0, 0), sol.states[end]))
end
return real(expect(projection(2, 0, 0), sol.states[end]))
end

# For SciMLSensitivity.jl
coef_Ω(p, t) = p[1]
H_evo = QobjEvo(sigmax(), coef_Ω)

function my_f_sesolve(p)
sol = sesolve(
H_evo,
ψ0,
tlist,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(projection(2, 0, 0), sol.states[end]))
end
# For SciMLSensitivity.jl
coef_Ω(p, t) = p[1]
const H_evo = QobjEvo(sigmax(), coef_Ω)

function my_f_sesolve(p)
sol = sesolve(
H_evo,
ψ0_sesolve,
tlist_sesolve,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(projection(2, 0, 0), sol.states[end]))
end

# Analytical solution
my_f_analytic(Ω) = abs2(sin(Ω * t_max))
my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max)
# Analytical solution
my_f_analytic(Ω) = abs2(sin(Ω * t_max))
my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max)

# ---- MESOLVE ----
const N = 20
const a = destroy(N)
const ψ0_mesolve = fock(N, 0)
const tlist_mesolve = range(0, 40, 100)

# For direct Forward differentiation
function my_f_mesolve_direct(p)
H = p[1] * a' * a + p[2] * (a + a')
c_ops = [sqrt(p[3]) * a]
sol = mesolve(H, ψ0_mesolve, tlist_mesolve, c_ops, progress_bar = Val(false))
return real(expect(a' * a, sol.states[end]))
end

# For SciMLSensitivity.jl
coef_Δ(p, t) = p[1]
coef_F(p, t) = p[2]
coef_γ(p, t) = sqrt(p[3])
H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
c_ops = [QobjEvo(a, coef_γ)]
const L = liouvillian(H, c_ops)

function my_f_mesolve(p)
sol = mesolve(
L,
ψ0_mesolve,
tlist_mesolve,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(a' * a, sol.states[end]))
end

# Analytical solution
n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2))

@testset "Autodiff" verbose=true begin
@testset "sesolve" verbose=true begin
Ω = 1.0
params = [Ω]

Expand All @@ -52,46 +91,21 @@

@test grad_qt ≈ grad_exact atol=1e-6
end
end

@testset "mesolve" verbose=true begin
N = 20
a = destroy(N)
ψ0 = fock(N, 0)
tlist = range(0, 40, 100)

# For direct Forward differentiation
function my_f_mesolve_direct(p)
H = p[1] * a' * a + p[2] * (a + a')
c_ops = [sqrt(p[3]) * a]
sol = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false))
return real(expect(a' * a, sol.states[end]))
end
@testset "Enzyme.jl" begin
dparams = Enzyme.make_zero(params)
Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.Reverse),
my_f_sesolve,
Active,
Duplicated(params, dparams),
)[1]

# For SciMLSensitivity.jl
coef_Δ(p, t) = p[1]
coef_F(p, t) = p[2]
coef_γ(p, t) = sqrt(p[3])
H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
c_ops = [QobjEvo(a, coef_γ)]
L = liouvillian(H, c_ops)

function my_f_mesolve(p)
sol = mesolve(
L,
ψ0,
tlist,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(a' * a, sol.states[end]))
@test dparams ≈ grad_exact atol=1e-6
end
end

# Analytical solution
n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2))

@testset "mesolve" verbose=true begin
Δ = 1.0
F = 1.0
γ = 1.0
Expand All @@ -111,5 +125,17 @@
grad_qt = Zygote.gradient(my_f_mesolve, params)[1]
@test grad_qt ≈ grad_exact atol=1e-6
end

@testset "Enzyme.jl" begin
dparams = Enzyme.make_zero(params)
Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.Reverse),
my_f_mesolve,
Active,
Duplicated(params, dparams),
)[1]

@test dparams ≈ grad_exact atol=1e-6
end
end
end
Loading