diff --git a/CHANGELOG.md b/CHANGELOG.md index 8671bd3c8..21dbd00a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `average_states` - `average_expect` - `std_expect` +- Add support to ForwardDiff.jl for `sesolve` and `mesolve`. ([#515]) ## [v0.33.0] Release date: 2025-07-22 @@ -284,3 +285,4 @@ Release date: 2024-11-13 [#509]: https://github.com/qutip/QuantumToolbox.jl/issues/509 [#512]: https://github.com/qutip/QuantumToolbox.jl/issues/512 [#513]: https://github.com/qutip/QuantumToolbox.jl/issues/513 +[#515]: https://github.com/qutip/QuantumToolbox.jl/issues/515 diff --git a/src/qobj/quantum_object_base.jl b/src/qobj/quantum_object_base.jl index deff612c4..6ae1024d3 100644 --- a/src/qobj/quantum_object_base.jl +++ b/src/qobj/quantum_object_base.jl @@ -245,5 +245,5 @@ _get_dims_length(::Space) = 1 _get_dims_length(::EnrSpace{N}) where {N} = N # functions for getting Float or Complex element type -_FType(A::AbstractQuantumObject) = _FType(eltype(A)) -_CType(A::AbstractQuantumObject) = _CType(eltype(A)) +_float_type(A::AbstractQuantumObject) = _float_type(eltype(A)) +_complex_float_type(A::AbstractQuantumObject) = _complex_float_type(eltype(A)) diff --git a/src/spectrum.jl b/src/spectrum.jl index 9a6313ad4..c9473bc17 100644 --- a/src/spectrum.jl +++ b/src/spectrum.jl @@ -127,9 +127,9 @@ function _spectrum( ) check_dimensions(L, A, B) - ωList = convert(Vector{_FType(L)}, ωlist) # Convert it to support GPUs and avoid type instabilities + ωList = convert(Vector{_float_type(L)}, ωlist) # Convert it to support GPUs and avoid type instabilities Length = length(ωList) - spec = Vector{_FType(L)}(undef, Length) + spec = Vector{_float_type(L)}(undef, Length) # calculate vectorized steadystate, multiply by operator B on the left (spre) ρss = mat2vec(steadystate(L)) @@ -137,7 +137,7 @@ function _spectrum( # multiply by operator A on the left (spre) and then perform trace operation D = prod(L.dimensions) - _tr = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(_CType(L), D)) # same as vec(system_identity_matrix) + _tr = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(_complex_float_type(L), D)) # same as vec(system_identity_matrix) _tr_A = transpose(_tr) * spre(A).data Id = I(D^2) @@ -169,8 +169,8 @@ function _spectrum( check_dimensions(L, A, B) # Define type shortcuts - fT = _FType(L) - cT = _CType(L) + fT = _float_type(L) + cT = _complex_float_type(L) # Calculate |v₁> = B|ρss> ρss = diff --git a/src/steadystate.jl b/src/steadystate.jl index d17ad4783..a81fe3b3f 100644 --- a/src/steadystate.jl +++ b/src/steadystate.jl @@ -206,7 +206,7 @@ function _steadystate(L::QuantumObject{SuperOperator}, solver::SteadyStateODESol abstol = haskey(kwargs, :abstol) ? kwargs[:abstol] : DEFAULT_ODE_SOLVER_OPTIONS.abstol reltol = haskey(kwargs, :reltol) ? kwargs[:reltol] : DEFAULT_ODE_SOLVER_OPTIONS.reltol - ftype = _FType(ψ0) + ftype = _float_type(ψ0) _terminate_func = SteadyStateODECondition(similar(mat2vec(ket2dm(ψ0)).data)) cb = TerminateSteadyState(abstol, reltol, _terminate_func) sol = mesolve( diff --git a/src/time_evolution/lr_mesolve.jl b/src/time_evolution/lr_mesolve.jl index 1c73010e1..de49df963 100644 --- a/src/time_evolution/lr_mesolve.jl +++ b/src/time_evolution/lr_mesolve.jl @@ -412,7 +412,7 @@ function lr_mesolveProblem( c_ops = get_data.(c_ops) e_ops = get_data.(e_ops) - t_l = _check_tlist(tlist, _FType(H)) + t_l = _check_tlist(tlist, _float_type(H)) # Initialization of Arrays expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l)) diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index 647be42d4..ddd4379cb 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -125,7 +125,7 @@ function mcsolveProblem( c_ops isa Nothing && throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead.")) - tlist = _check_tlist(tlist, _FType(ψ0)) + tlist = _check_tlist(tlist, _float_type(ψ0)) H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, c_ops) diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index 9dfd22273..1ab033bd3 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -79,16 +79,16 @@ function mesolveProblem( haskey(kwargs, :save_idxs) && throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox.")) - tlist = _check_tlist(tlist, _FType(ψ0)) + tlist = _check_tlist(tlist, _float_type(ψ0)) L_evo = _mesolve_make_L_QobjEvo(H, c_ops) check_dimensions(L_evo, ψ0) T = Base.promote_eltype(L_evo, ψ0) ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type - to_dense(_CType(T), copy(ψ0.data)) + to_dense(_complex_float_type(T), copy(ψ0.data)) else - to_dense(_CType(T), mat2vec(ket2dm(ψ0).data)) + to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data)) end L = L_evo.data diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index cd635a705..b447307a3 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -61,14 +61,14 @@ function sesolveProblem( haskey(kwargs, :save_idxs) && throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox.")) - tlist = _check_tlist(tlist, _FType(ψ0)) + tlist = _check_tlist(tlist, _float_type(ψ0)) H_evo = _sesolve_make_U_QobjEvo(H) # Multiply by -i isoper(H_evo) || throw(ArgumentError("The Hamiltonian must be an Operator.")) check_dimensions(H_evo, ψ0) T = Base.promote_eltype(H_evo, ψ0) - ψ0 = to_dense(_CType(T), get_data(ψ0)) # Convert it to dense vector with complex element type + ψ0 = to_dense(_complex_float_type(T), get_data(ψ0)) # Convert it to dense vector with complex element type U = H_evo.data kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...) diff --git a/src/time_evolution/smesolve.jl b/src/time_evolution/smesolve.jl index 1d28d4e02..1f1aa3363 100644 --- a/src/time_evolution/smesolve.jl +++ b/src/time_evolution/smesolve.jl @@ -94,7 +94,7 @@ function smesolveProblem( sc_ops_list = _make_c_ops_list(sc_ops) # If it is an AbstractQuantumObject but we need to iterate sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject - tlist = _check_tlist(tlist, _FType(ψ0)) + tlist = _check_tlist(tlist, _float_type(ψ0)) L_evo = _mesolve_make_L_QobjEvo(H, c_ops) + _mesolve_make_L_QobjEvo(nothing, sc_ops_list) check_dimensions(L_evo, ψ0) @@ -102,9 +102,9 @@ function smesolveProblem( T = Base.promote_eltype(L_evo, ψ0) ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type - to_dense(_CType(T), copy(ψ0.data)) + to_dense(_complex_float_type(T), copy(ψ0.data)) else - to_dense(_CType(T), mat2vec(ket2dm(ψ0).data)) + to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data)) end progr = ProgressBar(length(tlist), enable = getVal(progress_bar)) diff --git a/src/time_evolution/ssesolve.jl b/src/time_evolution/ssesolve.jl index 487a3a06a..962a831ad 100644 --- a/src/time_evolution/ssesolve.jl +++ b/src/time_evolution/ssesolve.jl @@ -94,14 +94,14 @@ function ssesolveProblem( sc_ops_list = _make_c_ops_list(sc_ops) # If it is an AbstractQuantumObject but we need to iterate sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject - tlist = _check_tlist(tlist, _FType(ψ0)) + tlist = _check_tlist(tlist, _float_type(ψ0)) H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, sc_ops_list) isoper(H_eff_evo) || throw(ArgumentError("The Hamiltonian must be an Operator.")) check_dimensions(H_eff_evo, ψ0) dims = H_eff_evo.dimensions - ψ0 = to_dense(_CType(ψ0), get_data(ψ0)) + ψ0 = to_dense(_complex_float_type(ψ0), get_data(ψ0)) progr = ProgressBar(length(tlist), enable = getVal(progress_bar)) diff --git a/src/utilities.jl b/src/utilities.jl index ebe9481f4..e30d6785c 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -46,7 +46,7 @@ where ``\hbar`` is the reduced Planck constant, and ``k_B`` is the Boltzmann con function n_thermal(ω::T1, ω_th::T2) where {T1<:Real,T2<:Real} x = exp(ω / ω_th) n = ((x != 1) && (ω_th > 0)) ? 1 / (x - 1) : 0 - return _FType(promote_type(T1, T2))(n) + return _float_type(promote_type(T1, T2))(n) end @doc raw""" @@ -125,7 +125,7 @@ julia> round(convert_unit(1, :meV, :mK), digits=4) function convert_unit(value::T, unit1::Symbol, unit2::Symbol) where {T<:Real} !haskey(_energy_units, unit1) && throw(ArgumentError("Invalid unit :$(unit1)")) !haskey(_energy_units, unit2) && throw(ArgumentError("Invalid unit :$(unit2)")) - return _FType(T)(value * (_energy_units[unit1] / _energy_units[unit2])) + return _float_type(T)(value * (_energy_units[unit1] / _energy_units[unit2])) end get_typename_wrapper(A) = Base.typename(typeof(A)).wrapper @@ -174,24 +174,26 @@ for AType in (:AbstractArray, :AbstractSciMLOperator) end # functions for getting Float or Complex element type -_FType(::AbstractArray{T}) where {T<:Number} = _FType(T) -_FType(::Type{Int32}) = Float32 -_FType(::Type{Int64}) = Float64 -_FType(::Type{Float32}) = Float32 -_FType(::Type{Float64}) = Float64 -_FType(::Type{Complex{Int32}}) = Float32 -_FType(::Type{Complex{Int64}}) = Float64 -_FType(::Type{Complex{Float32}}) = Float32 -_FType(::Type{Complex{Float64}}) = Float64 -_CType(::AbstractArray{T}) where {T<:Number} = _CType(T) -_CType(::Type{Int32}) = ComplexF32 -_CType(::Type{Int64}) = ComplexF64 -_CType(::Type{Float32}) = ComplexF32 -_CType(::Type{Float64}) = ComplexF64 -_CType(::Type{Complex{Int32}}) = ComplexF32 -_CType(::Type{Complex{Int64}}) = ComplexF64 -_CType(::Type{Complex{Float32}}) = ComplexF32 -_CType(::Type{Complex{Float64}}) = ComplexF64 +_float_type(::AbstractArray{T}) where {T<:Number} = _float_type(T) +_float_type(::Type{Int32}) = Float32 +_float_type(::Type{Int64}) = Float64 +_float_type(::Type{Float32}) = Float32 +_float_type(::Type{Float64}) = Float64 +_float_type(::Type{Complex{Int32}}) = Float32 +_float_type(::Type{Complex{Int64}}) = Float64 +_float_type(::Type{Complex{Float32}}) = Float32 +_float_type(::Type{Complex{Float64}}) = Float64 +_float_type(T::Type{<:Real}) = T # Allow other untracked Real types, like ForwardDiff.Dual +_complex_float_type(::AbstractArray{T}) where {T<:Number} = _complex_float_type(T) +_complex_float_type(::Type{Int32}) = ComplexF32 +_complex_float_type(::Type{Int64}) = ComplexF64 +_complex_float_type(::Type{Float32}) = ComplexF32 +_complex_float_type(::Type{Float64}) = ComplexF64 +_complex_float_type(::Type{Complex{Int32}}) = ComplexF32 +_complex_float_type(::Type{Complex{Int64}}) = ComplexF64 +_complex_float_type(::Type{Complex{Float32}}) = ComplexF32 +_complex_float_type(::Type{Complex{Float64}}) = ComplexF64 +_complex_float_type(T::Type{<:Complex}) = T # Allow other untracked Complex types, like ForwardDiff.Dual _convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:Int} = Int64 _convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:Int} = Int32 diff --git a/test/ext-test/cpu/autodiff/Project.toml b/test/ext-test/cpu/autodiff/Project.toml index e07db5ee7..f52a45c2b 100644 --- a/test/ext-test/cpu/autodiff/Project.toml +++ b/test/ext-test/cpu/autodiff/Project.toml @@ -1,5 +1,6 @@ [deps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" \ No newline at end of file +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/ext-test/cpu/autodiff/autodiff.jl b/test/ext-test/cpu/autodiff/autodiff.jl new file mode 100644 index 000000000..57bf60b5b --- /dev/null +++ b/test/ext-test/cpu/autodiff/autodiff.jl @@ -0,0 +1,115 @@ +@testset "Autodiff" verbose=true begin + @testset "sesolve" verbose=true begin + ψ0 = fock(2, 1) + t_max = 10 + tlist = range(0, t_max, 100) + + # For direct Forward differentiation + function my_f_sesolve_direct(p) + H = p[1] * sigmax() + sol = sesolve(H, ψ0, tlist, progress_bar = Val(false)) + + return real(expect(projection(2, 0, 0), sol.states[end])) + end + + # For SciMLSensitivity.jl + coef_Ω(p, t) = p[1] + H_evo = QobjEvo(sigmax(), coef_Ω) + + function my_f_sesolve(p) + sol = sesolve( + H_evo, + ψ0, + tlist, + progress_bar = Val(false), + params = p, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + ) + + return real(expect(projection(2, 0, 0), sol.states[end])) + end + + # Analytical solution + my_f_analytic(Ω) = abs2(sin(Ω * t_max)) + my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max) + + Ω = 1.0 + params = [Ω] + + my_f_sesolve_direct(params) + my_f_sesolve(params) + + grad_exact = [my_f_analytic_deriv(params[1])] + + @testset "ForwardDiff.jl" begin + grad_qt = ForwardDiff.gradient(my_f_sesolve_direct, params) + + @test grad_qt ≈ grad_exact atol=1e-6 + end + + @testset "Zygote.jl" begin + grad_qt = Zygote.gradient(my_f_sesolve, params)[1] + + @test grad_qt ≈ grad_exact atol=1e-6 + end + end + + @testset "mesolve" verbose=true begin + N = 20 + a = destroy(N) + ψ0 = fock(N, 0) + tlist = range(0, 40, 100) + + # For direct Forward differentiation + function my_f_mesolve_direct(p) + H = p[1] * a' * a + p[2] * (a + a') + c_ops = [sqrt(p[3]) * a] + sol = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false)) + return real(expect(a' * a, sol.states[end])) + end + + # For SciMLSensitivity.jl + coef_Δ(p, t) = p[1] + coef_F(p, t) = p[2] + coef_γ(p, t) = sqrt(p[3]) + H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F) + c_ops = [QobjEvo(a, coef_γ)] + L = liouvillian(H, c_ops) + + function my_f_mesolve(p) + sol = mesolve( + L, + ψ0, + tlist, + progress_bar = Val(false), + params = p, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + ) + + return real(expect(a' * a, sol.states[end])) + end + + # Analytical solution + n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2)) + + Δ = 1.0 + F = 1.0 + γ = 1.0 + params = [Δ, F, γ] + + my_f_mesolve_direct(params) + my_f_mesolve(params) + + grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1] + + @testset "ForwardDiff.jl" begin + grad_qt = ForwardDiff.gradient(my_f_mesolve_direct, params) + @test grad_qt ≈ grad_exact atol=1e-6 + end + + @testset "Zygote.jl" begin + grad_qt = Zygote.gradient(my_f_mesolve, params)[1] + @test grad_qt ≈ grad_exact atol=1e-6 + end + end +end diff --git a/test/ext-test/cpu/autodiff/zygote.jl b/test/ext-test/cpu/autodiff/zygote.jl deleted file mode 100644 index bf2775a05..000000000 --- a/test/ext-test/cpu/autodiff/zygote.jl +++ /dev/null @@ -1,84 +0,0 @@ -@testset "Zygote Extension" verbose=true begin - @testset "sesolve" begin - coef_Ω(p, t) = p[1] - - H = QobjEvo(sigmax(), coef_Ω) - ψ0 = fock(2, 1) - t_max = 10 - - function my_f_sesolve(p) - tlist = range(0, t_max, 100) - - sol = sesolve( - H, - ψ0, - tlist, - progress_bar = Val(false), - params = p, - sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), - ) - - return real(expect(projection(2, 0, 0), sol.states[end])) - end - - # Analytical solution - my_f_analytic(Ω) = abs2(sin(Ω * t_max)) - my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max) - - Ω = 1.0 - params = [Ω] - - my_f_analytic(Ω) - my_f_sesolve(params) - - grad_qt = Zygote.gradient(my_f_sesolve, params)[1] - grad_exact = [my_f_analytic_deriv(params[1])] - - @test grad_qt ≈ grad_exact atol=1e-6 - end - - @testset "mesolve" begin - N = 20 - a = destroy(N) - - coef_Δ(p, t) = p[1] - coef_F(p, t) = p[2] - coef_γ(p, t) = sqrt(p[3]) - - H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F) - c_ops = [QobjEvo(a, coef_γ)] - L = liouvillian(H, c_ops) - - ψ0 = fock(N, 0) - - function my_f_mesolve(p) - tlist = range(0, 40, 100) - - sol = mesolve( - L, - ψ0, - tlist, - progress_bar = Val(false), - params = p, - sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), - ) - - return real(expect(a' * a, sol.states[end])) - end - - # Analytical solution - n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2)) - - Δ = 1.0 - F = 1.0 - γ = 1.0 - params = [Δ, F, γ] - - # The factor 2 is due to a bug - grad_qt = Zygote.gradient(my_f_mesolve, params)[1] - - grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1] - - @test grad_qt ≈ grad_exact atol=1e-6 - end -end diff --git a/test/runtests.jl b/test/runtests.jl index b650417cf..eab6514c8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,13 +42,14 @@ if (GROUP == "AutoDiff_Ext") Pkg.instantiate() using QuantumToolbox + using ForwardDiff using Zygote using Enzyme using SciMLSensitivity QuantumToolbox.about() - include(joinpath(testdir, "ext-test", "cpu", "autodiff", "zygote.jl")) + include(joinpath(testdir, "ext-test", "cpu", "autodiff", "autodiff.jl")) end if (GROUP == "Makie_Ext")