From a6893ee70615231f18bd73e2b34ebc239213110b Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sun, 6 Apr 2025 21:08:01 +0200 Subject: [PATCH 1/9] Working mesolve --- .github/workflows/CI.yml | 12 +++----- Project.toml | 3 ++ ext/QuantumToolboxChainRulesCore.jl | 14 +++++++++ test/ext-test/autodiff/Project.toml | 5 +++ test/ext-test/autodiff/zygote.jl | 48 +++++++++++++++++++++++++++++ test/runtests.jl | 17 ++++++++-- 6 files changed, 90 insertions(+), 9 deletions(-) create mode 100644 ext/QuantumToolboxChainRulesCore.jl create mode 100644 test/ext-test/autodiff/Project.toml create mode 100644 test/ext-test/autodiff/zygote.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c32f88395..1e48375f1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -55,19 +55,17 @@ jobs: - 'Core' include: - # for core tests (intermediate versions) - # - version: '1.x' - # node: - # os: 'ubuntu-latest' - # arch: 'x64' - # group: 'Core' - # for extension tests - version: '1' node: os: 'ubuntu-latest' arch: 'x64' group: 'CairoMakie_Ext' + - version: '1' + node: + os: 'ubuntu-latest' + arch: 'x64' + group: 'AutoDiff_Ext' steps: - uses: actions/checkout@v4 diff --git a/Project.toml b/Project.toml index dbf8fe0ca..943b45e7d 100644 --- a/Project.toml +++ b/Project.toml @@ -30,16 +30,19 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [extensions] QuantumToolboxCUDAExt = "CUDA" QuantumToolboxCairoMakieExt = "CairoMakie" QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"] +QuantumToolboxChainRulesCore = "ChainRulesCore" [compat] ArrayInterface = "6, 7" CUDA = "5" CairoMakie = "0.12, 0.13" +ChainRulesCore = "1" DiffEqBase = "6" DiffEqCallbacks = "4.2.1 - 4" DiffEqNoiseProcess = "5" diff --git a/ext/QuantumToolboxChainRulesCore.jl b/ext/QuantumToolboxChainRulesCore.jl new file mode 100644 index 000000000..29789821b --- /dev/null +++ b/ext/QuantumToolboxChainRulesCore.jl @@ -0,0 +1,14 @@ +module QuantumToolboxChainRulesCore + +using LinearAlgebra +import QuantumToolbox: QuantumObject +using ChainRulesCore + +function ChainRulesCore.rrule(::Type{QuantumObject}, data, type, dimensions) + obj = QuantumObject(data, type, dimensions) + f_pullback(Δobj) = (NoTangent(), Δobj.data, NoTangent(), NoTangent()) + f_pullback(Δobj_data::AbstractArray) = (NoTangent(), Δobj_data, NoTangent(), NoTangent()) + return obj, f_pullback +end + +end diff --git a/test/ext-test/autodiff/Project.toml b/test/ext-test/autodiff/Project.toml new file mode 100644 index 000000000..314afe38b --- /dev/null +++ b/test/ext-test/autodiff/Project.toml @@ -0,0 +1,5 @@ +[deps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/ext-test/autodiff/zygote.jl b/test/ext-test/autodiff/zygote.jl new file mode 100644 index 000000000..0554a14c9 --- /dev/null +++ b/test/ext-test/autodiff/zygote.jl @@ -0,0 +1,48 @@ +@testset "Zygote.jl Autodiff" begin + @testset "mesolve" begin + N = 16 + a = destroy(N) + + coef_Δ(p, t) = p[1] + coef_F(p, t) = p[2] + coef_γ(p, t) = sqrt(p[3]) + + H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F) + c_ops = [QobjEvo(a, coef_γ)] + L = liouvillian(H, c_ops) + + ψ0 = fock(N, 0) + + function my_f_mesolve(p) + tlist = range(0, 40, 100) + + sol = mesolve( + L, + ψ0, + tlist, + progress_bar = Val(false), + params = p, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + ) + + return real(expect(a' * a, sol.states[end])) + end + + # Analytical solution + n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2)) + + Δ = 1.0 + F = 1.0 + γ = 1.0 + params = [Δ, F, γ] + my_f_mesolve(params) + n_ss(Δ, F, γ) + + # The factor 2 is due to a bug + grad_qt = Zygote.gradient(my_f_mesolve, params)[1] ./ 2 + + grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1] + + @test grad_qt ≈ grad_exact atol=1e-6 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d8670d9d3..4a654bf6e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,7 +55,20 @@ if (GROUP == "All") || (GROUP == "Code-Quality") include(joinpath(testdir, "core-test", "code-quality", "code_quality.jl")) end -if (GROUP == "CairoMakie_Ext")# || (GROUP == "All") +if (GROUP == "AutoDiff_Ext") + Pkg.activate("ext-test/cpu/autodiff") + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + Pkg.instantiate() + + using QuantumToolbox + using Zygote + using Enzyme + using SciMLSensitivity + + include(joinpath(testdir, "ext-test", "autodiff", "zygote.jl")) +end + +if (GROUP == "CairoMakie_Ext") Pkg.activate("ext-test/cpu/cairomakie") Pkg.develop(PackageSpec(path = dirname(@__DIR__))) Pkg.instantiate() @@ -67,7 +80,7 @@ if (GROUP == "CairoMakie_Ext")# || (GROUP == "All") include(joinpath(testdir, "ext-test", "cpu", "cairomakie", "cairomakie_ext.jl")) end -if (GROUP == "CUDA_Ext")# || (GROUP == "All") +if (GROUP == "CUDA_Ext") Pkg.activate("ext-test/gpu") Pkg.develop(PackageSpec(path = dirname(@__DIR__))) Pkg.instantiate() From d135b8fd975e5901f648d739ef2803c5d37e69c9 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 8 Apr 2025 01:44:14 +0200 Subject: [PATCH 2/9] Add sesolve tests --- .github/workflows/CI.yml | 7 +++ Project.toml | 2 +- ....jl => QuantumToolboxChainRulesCoreExt.jl} | 6 +-- src/QuantumToolbox.jl | 22 +++++++++ src/qobj/quantum_object_evo.jl | 2 - test/ext-test/autodiff/zygote.jl | 45 +++++++++++++++++-- 6 files changed, 74 insertions(+), 10 deletions(-) rename ext/{QuantumToolboxChainRulesCore.jl => QuantumToolboxChainRulesCoreExt.jl} (67%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1e48375f1..f34372fb4 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -55,6 +55,13 @@ jobs: - 'Core' include: + # for core tests (intermediate versions) + # - version: '1.x' + # node: + # os: 'ubuntu-latest' + # arch: 'x64' + # group: 'Core' + # for extension tests - version: '1' node: diff --git a/Project.toml b/Project.toml index 943b45e7d..c88a668db 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" QuantumToolboxCUDAExt = "CUDA" QuantumToolboxCairoMakieExt = "CairoMakie" QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"] -QuantumToolboxChainRulesCore = "ChainRulesCore" +QuantumToolboxChainRulesCoreExt = "ChainRulesCore" [compat] ArrayInterface = "6, 7" diff --git a/ext/QuantumToolboxChainRulesCore.jl b/ext/QuantumToolboxChainRulesCoreExt.jl similarity index 67% rename from ext/QuantumToolboxChainRulesCore.jl rename to ext/QuantumToolboxChainRulesCoreExt.jl index 29789821b..968d2e674 100644 --- a/ext/QuantumToolboxChainRulesCore.jl +++ b/ext/QuantumToolboxChainRulesCoreExt.jl @@ -1,10 +1,10 @@ -module QuantumToolboxChainRulesCore +module QuantumToolboxChainRulesCoreExt using LinearAlgebra import QuantumToolbox: QuantumObject -using ChainRulesCore +import ChainRulesCore: rrule, NoTangent, Tangent -function ChainRulesCore.rrule(::Type{QuantumObject}, data, type, dimensions) +function rrule(::Type{QuantumObject}, data, type, dimensions) obj = QuantumObject(data, type, dimensions) f_pullback(Δobj) = (NoTangent(), Δobj.data, NoTangent(), NoTangent()) f_pullback(Δobj_data::AbstractArray) = (NoTangent(), Δobj_data, NoTangent(), NoTangent()) diff --git a/src/QuantumToolbox.jl b/src/QuantumToolbox.jl index 409bcfa54..feb0de047 100644 --- a/src/QuantumToolbox.jl +++ b/src/QuantumToolbox.jl @@ -74,6 +74,28 @@ export permute # SciMLOperators export cache_operator, iscached, isconstant +# TODO: To remove when https://github.com/SciML/SciMLOperators.jl/pull/264 is merged +SCALINGNUMBERTYPES = (:AbstractSciMLScalarOperator, :Number, :UniformScaling) +# Special cases for constant scalars. These simplify the structure when applicable +for T in SCALINGNUMBERTYPES[2:end] + @eval function Base.:*(α::$T, L::ScaledOperator) + isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L) + return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule + end + @eval function Base.:*(L::ScaledOperator, α::$T) + isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L) + return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule + end + @eval function Base.:*(α::$T, L::MatrixOperator) + isconstant(L) && return MatrixOperator(α * L.A) + return ScaledOperator(α, L) # Going back to the generic case + end + @eval function Base.:*(L::MatrixOperator, α::$T) + isconstant(L) && return MatrixOperator(α * L.A) + return ScaledOperator(α, L) # Going back to the generic case + end +end + # Utility include("utilities.jl") include("versioninfo.jl") diff --git a/src/qobj/quantum_object_evo.jl b/src/qobj/quantum_object_evo.jl index 4d97d90d4..2a2dbe2d8 100644 --- a/src/qobj/quantum_object_evo.jl +++ b/src/qobj/quantum_object_evo.jl @@ -397,7 +397,6 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution` ) op = :(op_func_list[$i][1]) - data_type = op_type.parameters[1] dims_expr = (dims_expr..., :($op.dimensions)) func_methods_expr = (func_methods_expr..., :(methods(op_func_list[$i][2], [Any, Real]))) # [Any, Real] means each func must accept 2 arguments if i == 1 @@ -409,7 +408,6 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution` (isoper(op_type) || issuper(op_type)) || throw(ArgumentError("The element must be a Operator or SuperOperator.")) - data_type = op_type.parameters[1] dims_expr = (dims_expr..., :(op_func_list[$i].dimensions)) if i == 1 first_op = :(op_func_list[$i]) diff --git a/test/ext-test/autodiff/zygote.jl b/test/ext-test/autodiff/zygote.jl index 0554a14c9..cc235c488 100644 --- a/test/ext-test/autodiff/zygote.jl +++ b/test/ext-test/autodiff/zygote.jl @@ -1,6 +1,45 @@ -@testset "Zygote.jl Autodiff" begin +@testset "Zygote.jl Autodiff" verbose=true begin + @testset "sesolve" begin + coef_Ω(p, t) = p[1] + + H = QobjEvo(sigmax(), coef_Ω) + ψ0 = fock(2, 1) + t_max = 10 + + function my_f_sesolve(p) + tlist = range(0, t_max, 100) + + sol = sesolve( + H, + ψ0, + tlist, + progress_bar = Val(false), + params = p, + sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()), + ) + + return real(expect(projection(2, 0, 0), sol.states[end])) + end + + # Analytical solution + my_f_analytic(Ω) = abs2(sin(Ω * t_max)) + my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max) + + Ω = 1.0 + params = [Ω] + + my_f_analytic(Ω) + my_f_sesolve(params) + + + grad_qt = Zygote.gradient(my_f_sesolve, params)[1] ./ 2 + grad_exact = [my_f_analytic_deriv(params[1])] + + @test grad_qt ≈ grad_exact atol=1e-6 + end + @testset "mesolve" begin - N = 16 + N = 20 a = destroy(N) coef_Δ(p, t) = p[1] @@ -35,8 +74,6 @@ F = 1.0 γ = 1.0 params = [Δ, F, γ] - my_f_mesolve(params) - n_ss(Δ, F, γ) # The factor 2 is due to a bug grad_qt = Zygote.gradient(my_f_mesolve, params)[1] ./ 2 From 4e5cd14ebce8b44e16cb1d71b3ea406903ad587d Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 8 Apr 2025 01:51:57 +0200 Subject: [PATCH 3/9] Temporary fix of problem generation --- src/time_evolution/mesolve.jl | 8 +++++++- src/time_evolution/sesolve.jl | 8 +++++++- test/ext-test/autodiff/zygote.jl | 4 ++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index f0796f0c4..268535aa8 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -87,7 +87,13 @@ function mesolveProblem( kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncMESolve) tspan = (tlist[1], tlist[end]) - prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs3...) + + # TODO: Remove this when https://github.com/SciML/SciMLSensitivity.jl/issues/1181 is fixed + if haskey(kwargs3, :sensealg) + prob = ODEProblem{getVal(inplace)}(L, ρ0, tspan, params; kwargs3...) + else + prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs3...) + end return TimeEvolutionProblem(prob, tlist, L_evo.dimensions, (isoperket = Val(isoperket(ψ0)),)) end diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index 571b1a396..66f1df343 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -73,7 +73,13 @@ function sesolveProblem( kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSESolve) tspan = (tlist[1], tlist[end]) - prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs3...) + + # TODO: Remove this when https://github.com/SciML/SciMLSensitivity.jl/issues/1181 is fixed + if haskey(kwargs3, :sensealg) + prob = ODEProblem{getVal(inplace)}(U, ψ0, tspan, params; kwargs3...) + else + prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs3...) + end return TimeEvolutionProblem(prob, tlist, H_evo.dimensions) end diff --git a/test/ext-test/autodiff/zygote.jl b/test/ext-test/autodiff/zygote.jl index cc235c488..54bfadd8c 100644 --- a/test/ext-test/autodiff/zygote.jl +++ b/test/ext-test/autodiff/zygote.jl @@ -32,7 +32,7 @@ my_f_sesolve(params) - grad_qt = Zygote.gradient(my_f_sesolve, params)[1] ./ 2 + grad_qt = Zygote.gradient(my_f_sesolve, params)[1] grad_exact = [my_f_analytic_deriv(params[1])] @test grad_qt ≈ grad_exact atol=1e-6 @@ -76,7 +76,7 @@ params = [Δ, F, γ] # The factor 2 is due to a bug - grad_qt = Zygote.gradient(my_f_mesolve, params)[1] ./ 2 + grad_qt = Zygote.gradient(my_f_mesolve, params)[1] grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1] From 33173eec98addfae3434dee116611414602476dd Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 8 Apr 2025 02:03:33 +0200 Subject: [PATCH 4/9] Make changelog and format files --- CHANGELOG.md | 2 ++ test/ext-test/autodiff/zygote.jl | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98a8398b3..8b3c81077 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Make CUDA conversion more general using Adapt.jl. ([#436], [#437]) - Make the generation of `fock` states non-mutating to support Zygote.jl. ([#438]) - Remove Reexport.jl from the dependencies. ([#443]) +- Add support for automatic differentiation for `sesolve` and `mesolve`. ([#440]) ## [v0.29.1] Release date: 2025-03-07 @@ -195,3 +196,4 @@ Release date: 2024-11-13 [#437]: https://github.com/qutip/QuantumToolbox.jl/issues/437 [#438]: https://github.com/qutip/QuantumToolbox.jl/issues/438 [#443]: https://github.com/qutip/QuantumToolbox.jl/issues/443 +[#440]: https://github.com/qutip/QuantumToolbox.jl/issues/440 diff --git a/test/ext-test/autodiff/zygote.jl b/test/ext-test/autodiff/zygote.jl index 54bfadd8c..824da7a12 100644 --- a/test/ext-test/autodiff/zygote.jl +++ b/test/ext-test/autodiff/zygote.jl @@ -31,13 +31,12 @@ my_f_analytic(Ω) my_f_sesolve(params) - grad_qt = Zygote.gradient(my_f_sesolve, params)[1] grad_exact = [my_f_analytic_deriv(params[1])] @test grad_qt ≈ grad_exact atol=1e-6 end - + @testset "mesolve" begin N = 20 a = destroy(N) From 8a8c9daf2e68837680fc5bacf2e094259acccb0b Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 8 Apr 2025 11:09:33 +0200 Subject: [PATCH 5/9] Move tests into cpu folder --- test/ext-test/{ => cpu}/autodiff/Project.toml | 0 test/ext-test/{ => cpu}/autodiff/zygote.jl | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename test/ext-test/{ => cpu}/autodiff/Project.toml (100%) rename test/ext-test/{ => cpu}/autodiff/zygote.jl (100%) diff --git a/test/ext-test/autodiff/Project.toml b/test/ext-test/cpu/autodiff/Project.toml similarity index 100% rename from test/ext-test/autodiff/Project.toml rename to test/ext-test/cpu/autodiff/Project.toml diff --git a/test/ext-test/autodiff/zygote.jl b/test/ext-test/cpu/autodiff/zygote.jl similarity index 100% rename from test/ext-test/autodiff/zygote.jl rename to test/ext-test/cpu/autodiff/zygote.jl From 2d0e5655d7e51373ed02570356ef8ba350955feb Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 8 Apr 2025 11:27:15 +0200 Subject: [PATCH 6/9] Fixe path error --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4a654bf6e..8b026b207 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -65,7 +65,7 @@ if (GROUP == "AutoDiff_Ext") using Enzyme using SciMLSensitivity - include(joinpath(testdir, "ext-test", "autodiff", "zygote.jl")) + include(joinpath(testdir, "ext-test", "cpu", "autodiff", "zygote.jl")) end if (GROUP == "CairoMakie_Ext") From a31280df5b70ad461d19819ec7719f35dc62feea Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Fri, 11 Apr 2025 12:50:05 +0200 Subject: [PATCH 7/9] Reorder changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b3c81077..e4b73f275 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -195,5 +195,5 @@ Release date: 2024-11-13 [#436]: https://github.com/qutip/QuantumToolbox.jl/issues/436 [#437]: https://github.com/qutip/QuantumToolbox.jl/issues/437 [#438]: https://github.com/qutip/QuantumToolbox.jl/issues/438 -[#443]: https://github.com/qutip/QuantumToolbox.jl/issues/443 [#440]: https://github.com/qutip/QuantumToolbox.jl/issues/440 +[#443]: https://github.com/qutip/QuantumToolbox.jl/issues/443 From 3395e7b9577227b354ea31b071c7b36bf9e12230 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sat, 12 Apr 2025 00:14:11 +0200 Subject: [PATCH 8/9] Make special methods internally --- src/QuantumToolbox.jl | 22 -------------------- src/qobj/quantum_object_evo.jl | 29 +++++++++++++++++++++------ test/core-test/quantum_objects_evo.jl | 2 +- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/src/QuantumToolbox.jl b/src/QuantumToolbox.jl index feb0de047..409bcfa54 100644 --- a/src/QuantumToolbox.jl +++ b/src/QuantumToolbox.jl @@ -74,28 +74,6 @@ export permute # SciMLOperators export cache_operator, iscached, isconstant -# TODO: To remove when https://github.com/SciML/SciMLOperators.jl/pull/264 is merged -SCALINGNUMBERTYPES = (:AbstractSciMLScalarOperator, :Number, :UniformScaling) -# Special cases for constant scalars. These simplify the structure when applicable -for T in SCALINGNUMBERTYPES[2:end] - @eval function Base.:*(α::$T, L::ScaledOperator) - isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L) - return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule - end - @eval function Base.:*(L::ScaledOperator, α::$T) - isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L) - return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule - end - @eval function Base.:*(α::$T, L::MatrixOperator) - isconstant(L) && return MatrixOperator(α * L.A) - return ScaledOperator(α, L) # Going back to the generic case - end - @eval function Base.:*(L::MatrixOperator, α::$T) - isconstant(L) && return MatrixOperator(α * L.A) - return ScaledOperator(α, L) # Going back to the generic case - end -end - # Utility include("utilities.jl") include("versioninfo.jl") diff --git a/src/qobj/quantum_object_evo.jl b/src/qobj/quantum_object_evo.jl index 2a2dbe2d8..9f9364462 100644 --- a/src/qobj/quantum_object_evo.jl +++ b/src/qobj/quantum_object_evo.jl @@ -360,7 +360,7 @@ function QuantumObjectEvolution( if α isa Nothing return QuantumObjectEvolution(op.data, type, op.dimensions) end - return QuantumObjectEvolution(α * op.data, type, op.dimensions) + return QuantumObjectEvolution(_promote_to_scimloperator(α, op.data), type, op.dimensions) end #= @@ -443,16 +443,33 @@ function _make_SciMLOperator(op_func::Tuple, α) T = eltype(op_func[1]) update_func = (a, u, p, t) -> op_func[2](p, t) if α isa Nothing - return ScalarOperator(zero(T), update_func) * MatrixOperator(op_func[1].data) + return ScalarOperator(zero(T), update_func) * _promote_to_scimloperator(op_func[1].data) end - return ScalarOperator(zero(T), update_func) * MatrixOperator(α * op_func[1].data) + return ScalarOperator(zero(T), update_func) * _promote_to_scimloperator(α, op_func[1].data) end -function _make_SciMLOperator(op::QuantumObject, α) +function _make_SciMLOperator(op::AbstractQuantumObject, α) if α isa Nothing - return MatrixOperator(op.data) + return _promote_to_scimloperator(op.data) end - return MatrixOperator(α * op.data) + return _promote_to_scimloperator(α, op.data) +end + +_promote_to_scimloperator(data::AbstractMatrix) = MatrixOperator(data) +_promote_to_scimloperator(data::AbstractSciMLOperator) = data +# TODO: The following special cases can be simplified after +# https://github.com/SciML/SciMLOperators.jl/pull/264 is merged +_promote_to_scimloperator(α::Number, data::AbstractMatrix) = MatrixOperator(α * data) +function _promote_to_scimloperator(α::Number, data::MatrixOperator) + isconstant(data) && return MatrixOperator(α * data.A) + return ScaledOperator(α, data) # Going back to the generic case +end +function _promote_to_scimloperator(α::Number, data::ScaledOperator) + isconstant(data.λ) && return ScaledOperator(α * data.λ, data.L) + return ScaledOperator(data.λ, _promote_to_scimloperator(α, data.L)) # Try to propagate the rule +end +function _promote_to_scimloperator(α::Number, data::AbstractSciMLOperator) + return α * data # Going back to the generic case end @doc raw""" diff --git a/test/core-test/quantum_objects_evo.jl b/test/core-test/quantum_objects_evo.jl index 5367599fe..1e923cf31 100644 --- a/test/core-test/quantum_objects_evo.jl +++ b/test/core-test/quantum_objects_evo.jl @@ -19,7 +19,7 @@ @test_throws DimensionMismatch QobjEvo(a, type = SuperOperator) ψ = fock(10, 3) - @test_throws TypeError QobjEvo(ψ) + @test_throws MethodError QobjEvo(ψ) end # unsupported type of dims From 35c158977b399779e4be9cd7d01d6c3e6e461418 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sat, 12 Apr 2025 07:08:33 +0200 Subject: [PATCH 9/9] Change order in deps --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c88a668db..87766530a 100644 --- a/Project.toml +++ b/Project.toml @@ -28,15 +28,15 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [extensions] -QuantumToolboxCUDAExt = "CUDA" QuantumToolboxCairoMakieExt = "CairoMakie" -QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"] QuantumToolboxChainRulesCoreExt = "ChainRulesCore" +QuantumToolboxCUDAExt = "CUDA" +QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"] [compat] ArrayInterface = "6, 7"