diff --git a/CHANGELOG.md b/CHANGELOG.md index 46b6bb5cd..6b4c0e42e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main) - Improve Bloch sphere rendering for animation. ([#520]) +- Add support to `Enzyme.jl` for `sesolve` and `mesolve`. ([#531]) ## [v0.34.0] Release date: 2025-07-29 @@ -295,3 +296,4 @@ Release date: 2024-11-13 [#515]: https://github.com/qutip/QuantumToolbox.jl/issues/515 [#517]: https://github.com/qutip/QuantumToolbox.jl/issues/517 [#520]: https://github.com/qutip/QuantumToolbox.jl/issues/520 +[#531]: https://github.com/qutip/QuantumToolbox.jl/issues/531 diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index 8e790eeea..65eac4eae 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -178,6 +178,11 @@ function mesolve( kwargs..., ) + # Move sensealg argument to solve for Enzyme.jl support. + # TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed. + sensealg = get(kwargs, :sensealg, nothing) + kwargs_filtered = isnothing(sensealg) ? kwargs : Base.structdiff((; kwargs...), (sensealg = sensealg,)) + prob = mesolveProblem( H, ψ0, @@ -188,14 +193,19 @@ function mesolve( params = params, progress_bar = progress_bar, inplace = inplace, - kwargs..., + kwargs_filtered..., ) - return mesolve(prob, alg) + # TODO: Remove sensealg when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed + if isnothing(sensealg) + return mesolve(prob, alg) + else + return mesolve(prob, alg; sensealg = sensealg) + end end -function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5()) - sol = solve(prob.prob, alg) +function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(); kwargs...) + sol = solve(prob.prob, alg; kwargs...) # No type instabilities since `isoperket` is a Val, and so it is known at compile time if getVal(prob.kwargs.isoperket) diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index c30302e7d..f7fd3f7c4 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -135,6 +135,12 @@ function sesolve( inplace::Union{Val,Bool} = Val(true), kwargs..., ) + + # Move sensealg argument to solve for Enzyme.jl support. + # TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed. + sensealg = get(kwargs, :sensealg, nothing) + kwargs_filtered = isnothing(sensealg) ? kwargs : Base.structdiff((; kwargs...), (sensealg = sensealg,)) + prob = sesolveProblem( H, ψ0, @@ -143,14 +149,19 @@ function sesolve( params = params, progress_bar = progress_bar, inplace = inplace, - kwargs..., + kwargs_filtered..., ) - return sesolve(prob, alg) + # TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed. + if isnothing(sensealg) + return sesolve(prob, alg) + else + return sesolve(prob, alg; sensealg = sensealg) + end end -function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5()) - sol = solve(prob.prob, alg) +function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(); kwargs...) + sol = solve(prob.prob, alg; kwargs...) ψt = map(ϕ -> QuantumObject(ϕ, type = Ket(), dims = prob.dimensions), sol.u) diff --git a/test/ext-test/cpu/autodiff/autodiff.jl b/test/ext-test/cpu/autodiff/autodiff.jl index 57bf60b5b..3f8e601a4 100644 --- a/test/ext-test/cpu/autodiff/autodiff.jl +++ b/test/ext-test/cpu/autodiff/autodiff.jl @@ -1,38 +1,77 @@ -@testset "Autodiff" verbose=true begin - @testset "sesolve" verbose=true begin - ψ0 = fock(2, 1) - t_max = 10 - tlist = range(0, t_max, 100) +# ---- SESOLVE ---- +const ψ0_sesolve = fock(2, 1) +t_max = 10 +const tlist_sesolve = range(0, t_max, 100) - # For direct Forward differentiation - function my_f_sesolve_direct(p) - H = p[1] * sigmax() - sol = sesolve(H, ψ0, tlist, progress_bar = Val(false)) +# For direct Forward differentiation +function my_f_sesolve_direct(p) + H = p[1] * sigmax() + sol = sesolve(H, ψ0_sesolve, tlist_sesolve, progress_bar = Val(false)) - return real(expect(projection(2, 0, 0), sol.states[end])) - end + return real(expect(projection(2, 0, 0), sol.states[end])) +end - # For SciMLSensitivity.jl - coef_Ω(p, t) = p[1] - H_evo = QobjEvo(sigmax(), coef_Ω) - - function my_f_sesolve(p) - sol = sesolve( - H_evo, - ψ0, - tlist, - progress_bar = Val(false), - params = p, - sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), - ) - - return real(expect(projection(2, 0, 0), sol.states[end])) - end +# For SciMLSensitivity.jl +coef_Ω(p, t) = p[1] +const H_evo = QobjEvo(sigmax(), coef_Ω) + +function my_f_sesolve(p) + sol = sesolve( + H_evo, + ψ0_sesolve, + tlist_sesolve, + progress_bar = Val(false), + params = p, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + ) + + return real(expect(projection(2, 0, 0), sol.states[end])) +end - # Analytical solution - my_f_analytic(Ω) = abs2(sin(Ω * t_max)) - my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max) +# Analytical solution +my_f_analytic(Ω) = abs2(sin(Ω * t_max)) +my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max) + +# ---- MESOLVE ---- +const N = 20 +const a = destroy(N) +const ψ0_mesolve = fock(N, 0) +const tlist_mesolve = range(0, 40, 100) + +# For direct Forward differentiation +function my_f_mesolve_direct(p) + H = p[1] * a' * a + p[2] * (a + a') + c_ops = [sqrt(p[3]) * a] + sol = mesolve(H, ψ0_mesolve, tlist_mesolve, c_ops, progress_bar = Val(false)) + return real(expect(a' * a, sol.states[end])) +end + +# For SciMLSensitivity.jl +coef_Δ(p, t) = p[1] +coef_F(p, t) = p[2] +coef_γ(p, t) = sqrt(p[3]) +H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F) +c_ops = [QobjEvo(a, coef_γ)] +const L = liouvillian(H, c_ops) + +function my_f_mesolve(p) + sol = mesolve( + L, + ψ0_mesolve, + tlist_mesolve, + progress_bar = Val(false), + params = p, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + ) + + return real(expect(a' * a, sol.states[end])) +end +# Analytical solution +n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2)) + +@testset "Autodiff" verbose=true begin + @testset "sesolve" verbose=true begin Ω = 1.0 params = [Ω] @@ -52,46 +91,21 @@ @test grad_qt ≈ grad_exact atol=1e-6 end - end - @testset "mesolve" verbose=true begin - N = 20 - a = destroy(N) - ψ0 = fock(N, 0) - tlist = range(0, 40, 100) - - # For direct Forward differentiation - function my_f_mesolve_direct(p) - H = p[1] * a' * a + p[2] * (a + a') - c_ops = [sqrt(p[3]) * a] - sol = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false)) - return real(expect(a' * a, sol.states[end])) - end + @testset "Enzyme.jl" begin + dparams = Enzyme.make_zero(params) + Enzyme.autodiff( + Enzyme.set_runtime_activity(Enzyme.Reverse), + my_f_sesolve, + Active, + Duplicated(params, dparams), + )[1] - # For SciMLSensitivity.jl - coef_Δ(p, t) = p[1] - coef_F(p, t) = p[2] - coef_γ(p, t) = sqrt(p[3]) - H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F) - c_ops = [QobjEvo(a, coef_γ)] - L = liouvillian(H, c_ops) - - function my_f_mesolve(p) - sol = mesolve( - L, - ψ0, - tlist, - progress_bar = Val(false), - params = p, - sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), - ) - - return real(expect(a' * a, sol.states[end])) + @test dparams ≈ grad_exact atol=1e-6 end + end - # Analytical solution - n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2)) - + @testset "mesolve" verbose=true begin Δ = 1.0 F = 1.0 γ = 1.0 @@ -111,5 +125,17 @@ grad_qt = Zygote.gradient(my_f_mesolve, params)[1] @test grad_qt ≈ grad_exact atol=1e-6 end + + @testset "Enzyme.jl" begin + dparams = Enzyme.make_zero(params) + Enzyme.autodiff( + Enzyme.set_runtime_activity(Enzyme.Reverse), + my_f_mesolve, + Active, + Duplicated(params, dparams), + )[1] + + @test dparams ≈ grad_exact atol=1e-6 + end end end