Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `average_states`
- `average_expect`
- `std_expect`
- Add support to ForwardDiff.jl for `sesolve` and `mesolve`. ([#515])
Copy link
Member

@ytdHuang ytdHuang Jul 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the keep_runs_results log has many contents, I suggest we put this one before it.

This should increase the readability of the change log


## [v0.33.0]
Release date: 2025-07-22
Expand Down Expand Up @@ -284,3 +285,4 @@ Release date: 2024-11-13
[#509]: https://github.com/qutip/QuantumToolbox.jl/issues/509
[#512]: https://github.com/qutip/QuantumToolbox.jl/issues/512
[#513]: https://github.com/qutip/QuantumToolbox.jl/issues/513
[#515]: https://github.com/qutip/QuantumToolbox.jl/issues/515
4 changes: 2 additions & 2 deletions src/qobj/quantum_object_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,5 +245,5 @@ _get_dims_length(::Space) = 1
_get_dims_length(::EnrSpace{N}) where {N} = N

# functions for getting Float or Complex element type
_FType(A::AbstractQuantumObject) = _FType(eltype(A))
_CType(A::AbstractQuantumObject) = _CType(eltype(A))
_float_type(A::AbstractQuantumObject) = _float_type(eltype(A))
_complex_float_type(A::AbstractQuantumObject) = _complex_float_type(eltype(A))
10 changes: 5 additions & 5 deletions src/spectrum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,17 @@ function _spectrum(
)
check_dimensions(L, A, B)

ωList = convert(Vector{_FType(L)}, ωlist) # Convert it to support GPUs and avoid type instabilities
ωList = convert(Vector{_float_type(L)}, ωlist) # Convert it to support GPUs and avoid type instabilities
Length = length(ωList)
spec = Vector{_FType(L)}(undef, Length)
spec = Vector{_float_type(L)}(undef, Length)

# calculate vectorized steadystate, multiply by operator B on the left (spre)
ρss = mat2vec(steadystate(L))
b = (spre(B) * ρss).data

# multiply by operator A on the left (spre) and then perform trace operation
D = prod(L.dimensions)
_tr = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(_CType(L), D)) # same as vec(system_identity_matrix)
_tr = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(_complex_float_type(L), D)) # same as vec(system_identity_matrix)
_tr_A = transpose(_tr) * spre(A).data

Id = I(D^2)
Expand Down Expand Up @@ -169,8 +169,8 @@ function _spectrum(
check_dimensions(L, A, B)

# Define type shortcuts
fT = _FType(L)
cT = _CType(L)
fT = _float_type(L)
cT = _complex_float_type(L)

# Calculate |v₁> = B|ρss>
ρss =
Expand Down
2 changes: 1 addition & 1 deletion src/steadystate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ function _steadystate(L::QuantumObject{SuperOperator}, solver::SteadyStateODESol
abstol = haskey(kwargs, :abstol) ? kwargs[:abstol] : DEFAULT_ODE_SOLVER_OPTIONS.abstol
reltol = haskey(kwargs, :reltol) ? kwargs[:reltol] : DEFAULT_ODE_SOLVER_OPTIONS.reltol

ftype = _FType(ψ0)
ftype = _float_type(ψ0)
_terminate_func = SteadyStateODECondition(similar(mat2vec(ket2dm(ψ0)).data))
cb = TerminateSteadyState(abstol, reltol, _terminate_func)
sol = mesolve(
Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/lr_mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ function lr_mesolveProblem(
c_ops = get_data.(c_ops)
e_ops = get_data.(e_ops)

t_l = _check_tlist(tlist, _FType(H))
t_l = _check_tlist(tlist, _float_type(H))

# Initialization of Arrays
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ function mcsolveProblem(
c_ops isa Nothing &&
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))

tlist = _check_tlist(tlist, _FType(ψ0))
tlist = _check_tlist(tlist, _float_type(ψ0))

H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, c_ops)

Expand Down
6 changes: 3 additions & 3 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,16 @@ function mesolveProblem(
haskey(kwargs, :save_idxs) &&
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

tlist = _check_tlist(tlist, _FType(ψ0))
tlist = _check_tlist(tlist, _float_type(ψ0))

L_evo = _mesolve_make_L_QobjEvo(H, c_ops)
check_dimensions(L_evo, ψ0)

T = Base.promote_eltype(L_evo, ψ0)
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
to_dense(_CType(T), copy(ψ0.data))
to_dense(_complex_float_type(T), copy(ψ0.data))
else
to_dense(_CType(T), mat2vec(ket2dm(ψ0).data))
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
end
L = L_evo.data

Expand Down
4 changes: 2 additions & 2 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ function sesolveProblem(
haskey(kwargs, :save_idxs) &&
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

tlist = _check_tlist(tlist, _FType(ψ0))
tlist = _check_tlist(tlist, _float_type(ψ0))

H_evo = _sesolve_make_U_QobjEvo(H) # Multiply by -i
isoper(H_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
check_dimensions(H_evo, ψ0)

T = Base.promote_eltype(H_evo, ψ0)
ψ0 = to_dense(_CType(T), get_data(ψ0)) # Convert it to dense vector with complex element type
ψ0 = to_dense(_complex_float_type(T), get_data(ψ0)) # Convert it to dense vector with complex element type
U = H_evo.data

kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions src/time_evolution/smesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,17 @@ function smesolveProblem(
sc_ops_list = _make_c_ops_list(sc_ops) # If it is an AbstractQuantumObject but we need to iterate
sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject

tlist = _check_tlist(tlist, _FType(ψ0))
tlist = _check_tlist(tlist, _float_type(ψ0))

L_evo = _mesolve_make_L_QobjEvo(H, c_ops) + _mesolve_make_L_QobjEvo(nothing, sc_ops_list)
check_dimensions(L_evo, ψ0)
dims = L_evo.dimensions

T = Base.promote_eltype(L_evo, ψ0)
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
to_dense(_CType(T), copy(ψ0.data))
to_dense(_complex_float_type(T), copy(ψ0.data))
else
to_dense(_CType(T), mat2vec(ket2dm(ψ0).data))
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
end

progr = ProgressBar(length(tlist), enable = getVal(progress_bar))
Expand Down
4 changes: 2 additions & 2 deletions src/time_evolution/ssesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ function ssesolveProblem(
sc_ops_list = _make_c_ops_list(sc_ops) # If it is an AbstractQuantumObject but we need to iterate
sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject

tlist = _check_tlist(tlist, _FType(ψ0))
tlist = _check_tlist(tlist, _float_type(ψ0))

H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, sc_ops_list)
isoper(H_eff_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
check_dimensions(H_eff_evo, ψ0)
dims = H_eff_evo.dimensions

ψ0 = to_dense(_CType(ψ0), get_data(ψ0))
ψ0 = to_dense(_complex_float_type(ψ0), get_data(ψ0))

progr = ProgressBar(length(tlist), enable = getVal(progress_bar))

Expand Down
42 changes: 22 additions & 20 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ where ``\hbar`` is the reduced Planck constant, and ``k_B`` is the Boltzmann con
function n_thermal(ω::T1, ω_th::T2) where {T1<:Real,T2<:Real}
x = exp(ω / ω_th)
n = ((x != 1) && (ω_th > 0)) ? 1 / (x - 1) : 0
return _FType(promote_type(T1, T2))(n)
return _float_type(promote_type(T1, T2))(n)
end

@doc raw"""
Expand Down Expand Up @@ -125,7 +125,7 @@ julia> round(convert_unit(1, :meV, :mK), digits=4)
function convert_unit(value::T, unit1::Symbol, unit2::Symbol) where {T<:Real}
!haskey(_energy_units, unit1) && throw(ArgumentError("Invalid unit :$(unit1)"))
!haskey(_energy_units, unit2) && throw(ArgumentError("Invalid unit :$(unit2)"))
return _FType(T)(value * (_energy_units[unit1] / _energy_units[unit2]))
return _float_type(T)(value * (_energy_units[unit1] / _energy_units[unit2]))
end

get_typename_wrapper(A) = Base.typename(typeof(A)).wrapper
Expand Down Expand Up @@ -174,24 +174,26 @@ for AType in (:AbstractArray, :AbstractSciMLOperator)
end

# functions for getting Float or Complex element type
_FType(::AbstractArray{T}) where {T<:Number} = _FType(T)
_FType(::Type{Int32}) = Float32
_FType(::Type{Int64}) = Float64
_FType(::Type{Float32}) = Float32
_FType(::Type{Float64}) = Float64
_FType(::Type{Complex{Int32}}) = Float32
_FType(::Type{Complex{Int64}}) = Float64
_FType(::Type{Complex{Float32}}) = Float32
_FType(::Type{Complex{Float64}}) = Float64
_CType(::AbstractArray{T}) where {T<:Number} = _CType(T)
_CType(::Type{Int32}) = ComplexF32
_CType(::Type{Int64}) = ComplexF64
_CType(::Type{Float32}) = ComplexF32
_CType(::Type{Float64}) = ComplexF64
_CType(::Type{Complex{Int32}}) = ComplexF32
_CType(::Type{Complex{Int64}}) = ComplexF64
_CType(::Type{Complex{Float32}}) = ComplexF32
_CType(::Type{Complex{Float64}}) = ComplexF64
_float_type(::AbstractArray{T}) where {T<:Number} = _float_type(T)
_float_type(::Type{Int32}) = Float32
_float_type(::Type{Int64}) = Float64
_float_type(::Type{Float32}) = Float32
_float_type(::Type{Float64}) = Float64
_float_type(::Type{Complex{Int32}}) = Float32
_float_type(::Type{Complex{Int64}}) = Float64
_float_type(::Type{Complex{Float32}}) = Float32
_float_type(::Type{Complex{Float64}}) = Float64
_float_type(T::Type{<:Real}) = T # Allow other untracked Real types, like ForwardDiff.Dual
_complex_float_type(::AbstractArray{T}) where {T<:Number} = _complex_float_type(T)
_complex_float_type(::Type{Int32}) = ComplexF32
_complex_float_type(::Type{Int64}) = ComplexF64
_complex_float_type(::Type{Float32}) = ComplexF32
_complex_float_type(::Type{Float64}) = ComplexF64
_complex_float_type(::Type{Complex{Int32}}) = ComplexF32
_complex_float_type(::Type{Complex{Int64}}) = ComplexF64
_complex_float_type(::Type{Complex{Float32}}) = ComplexF32
_complex_float_type(::Type{Complex{Float64}}) = ComplexF64
_complex_float_type(T::Type{<:Complex}) = T # Allow other untracked Complex types, like ForwardDiff.Dual

_convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:Int} = Int64
_convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:Int} = Int32
Expand Down
3 changes: 2 additions & 1 deletion test/ext-test/cpu/autodiff/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
115 changes: 115 additions & 0 deletions test/ext-test/cpu/autodiff/autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
@testset "Autodiff" verbose=true begin
@testset "sesolve" verbose=true begin
ψ0 = fock(2, 1)
t_max = 10
tlist = range(0, t_max, 100)

# For direct Forward differentiation
function my_f_sesolve_direct(p)
H = p[1] * sigmax()
sol = sesolve(H, ψ0, tlist, progress_bar = Val(false))

return real(expect(projection(2, 0, 0), sol.states[end]))
end

# For SciMLSensitivity.jl
coef_Ω(p, t) = p[1]
H_evo = QobjEvo(sigmax(), coef_Ω)

function my_f_sesolve(p)
sol = sesolve(
H_evo,
ψ0,
tlist,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(projection(2, 0, 0), sol.states[end]))
end

# Analytical solution
my_f_analytic(Ω) = abs2(sin(Ω * t_max))
my_f_analytic_deriv(Ω) = 2 * t_max * sin(Ω * t_max) * cos(Ω * t_max)

Ω = 1.0
params = [Ω]

my_f_sesolve_direct(params)
my_f_sesolve(params)

grad_exact = [my_f_analytic_deriv(params[1])]

@testset "ForwardDiff.jl" begin
grad_qt = ForwardDiff.gradient(my_f_sesolve_direct, params)

@test grad_qt ≈ grad_exact atol=1e-6
end

@testset "Zygote.jl" begin
grad_qt = Zygote.gradient(my_f_sesolve, params)[1]

@test grad_qt ≈ grad_exact atol=1e-6
end
end

@testset "mesolve" verbose=true begin
N = 20
a = destroy(N)
ψ0 = fock(N, 0)
tlist = range(0, 40, 100)

# For direct Forward differentiation
function my_f_mesolve_direct(p)
H = p[1] * a' * a + p[2] * (a + a')
c_ops = [sqrt(p[3]) * a]
sol = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false))
return real(expect(a' * a, sol.states[end]))
end

# For SciMLSensitivity.jl
coef_Δ(p, t) = p[1]
coef_F(p, t) = p[2]
coef_γ(p, t) = sqrt(p[3])
H = QobjEvo(a' * a, coef_Δ) + QobjEvo(a + a', coef_F)
c_ops = [QobjEvo(a, coef_γ)]
L = liouvillian(H, c_ops)

function my_f_mesolve(p)
sol = mesolve(
L,
ψ0,
tlist,
progress_bar = Val(false),
params = p,
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
)

return real(expect(a' * a, sol.states[end]))
end

# Analytical solution
n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2))

Δ = 1.0
F = 1.0
γ = 1.0
params = [Δ, F, γ]

my_f_mesolve_direct(params)
my_f_mesolve(params)

grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1]

@testset "ForwardDiff.jl" begin
grad_qt = ForwardDiff.gradient(my_f_mesolve_direct, params)
@test grad_qt ≈ grad_exact atol=1e-6
end

@testset "Zygote.jl" begin
grad_qt = Zygote.gradient(my_f_mesolve, params)[1]
@test grad_qt ≈ grad_exact atol=1e-6
end
end
end
Loading
Loading