Skip to content

Commit 5ba859a

Browse files
Add support to ForwarDiff differentiation (#515)
1 parent 0bc7621 commit 5ba859a

File tree

15 files changed

+163
-126
lines changed

15 files changed

+163
-126
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919
- `average_states`
2020
- `average_expect`
2121
- `std_expect`
22+
- Add support to ForwardDiff.jl for `sesolve` and `mesolve`. ([#515])
2223

2324
## [v0.33.0]
2425
Release date: 2025-07-22
@@ -284,3 +285,4 @@ Release date: 2024-11-13
284285
[#509]: https://github.com/qutip/QuantumToolbox.jl/issues/509
285286
[#512]: https://github.com/qutip/QuantumToolbox.jl/issues/512
286287
[#513]: https://github.com/qutip/QuantumToolbox.jl/issues/513
288+
[#515]: https://github.com/qutip/QuantumToolbox.jl/issues/515

src/qobj/quantum_object_base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,5 +245,5 @@ _get_dims_length(::Space) = 1
245245
_get_dims_length(::EnrSpace{N}) where {N} = N
246246

247247
# functions for getting Float or Complex element type
248-
_FType(A::AbstractQuantumObject) = _FType(eltype(A))
249-
_CType(A::AbstractQuantumObject) = _CType(eltype(A))
248+
_float_type(A::AbstractQuantumObject) = _float_type(eltype(A))
249+
_complex_float_type(A::AbstractQuantumObject) = _complex_float_type(eltype(A))

src/spectrum.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,17 @@ function _spectrum(
127127
)
128128
check_dimensions(L, A, B)
129129

130-
ωList = convert(Vector{_FType(L)}, ωlist) # Convert it to support GPUs and avoid type instabilities
130+
ωList = convert(Vector{_float_type(L)}, ωlist) # Convert it to support GPUs and avoid type instabilities
131131
Length = length(ωList)
132-
spec = Vector{_FType(L)}(undef, Length)
132+
spec = Vector{_float_type(L)}(undef, Length)
133133

134134
# calculate vectorized steadystate, multiply by operator B on the left (spre)
135135
ρss = mat2vec(steadystate(L))
136136
b = (spre(B) * ρss).data
137137

138138
# multiply by operator A on the left (spre) and then perform trace operation
139139
D = prod(L.dimensions)
140-
_tr = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(_CType(L), D)) # same as vec(system_identity_matrix)
140+
_tr = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(_complex_float_type(L), D)) # same as vec(system_identity_matrix)
141141
_tr_A = transpose(_tr) * spre(A).data
142142

143143
Id = I(D^2)
@@ -169,8 +169,8 @@ function _spectrum(
169169
check_dimensions(L, A, B)
170170

171171
# Define type shortcuts
172-
fT = _FType(L)
173-
cT = _CType(L)
172+
fT = _float_type(L)
173+
cT = _complex_float_type(L)
174174

175175
# Calculate |v₁> = B|ρss>
176176
ρss =

src/steadystate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ function _steadystate(L::QuantumObject{SuperOperator}, solver::SteadyStateODESol
206206
abstol = haskey(kwargs, :abstol) ? kwargs[:abstol] : DEFAULT_ODE_SOLVER_OPTIONS.abstol
207207
reltol = haskey(kwargs, :reltol) ? kwargs[:reltol] : DEFAULT_ODE_SOLVER_OPTIONS.reltol
208208

209-
ftype = _FType(ψ0)
209+
ftype = _float_type(ψ0)
210210
_terminate_func = SteadyStateODECondition(similar(mat2vec(ket2dm(ψ0)).data))
211211
cb = TerminateSteadyState(abstol, reltol, _terminate_func)
212212
sol = mesolve(

src/time_evolution/lr_mesolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ function lr_mesolveProblem(
412412
c_ops = get_data.(c_ops)
413413
e_ops = get_data.(e_ops)
414414

415-
t_l = _check_tlist(tlist, _FType(H))
415+
t_l = _check_tlist(tlist, _float_type(H))
416416

417417
# Initialization of Arrays
418418
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))

src/time_evolution/mcsolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function mcsolveProblem(
125125
c_ops isa Nothing &&
126126
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
127127

128-
tlist = _check_tlist(tlist, _FType(ψ0))
128+
tlist = _check_tlist(tlist, _float_type(ψ0))
129129

130130
H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, c_ops)
131131

src/time_evolution/mesolve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,16 @@ function mesolveProblem(
7979
haskey(kwargs, :save_idxs) &&
8080
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
8181

82-
tlist = _check_tlist(tlist, _FType(ψ0))
82+
tlist = _check_tlist(tlist, _float_type(ψ0))
8383

8484
L_evo = _mesolve_make_L_QobjEvo(H, c_ops)
8585
check_dimensions(L_evo, ψ0)
8686

8787
T = Base.promote_eltype(L_evo, ψ0)
8888
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
89-
to_dense(_CType(T), copy(ψ0.data))
89+
to_dense(_complex_float_type(T), copy(ψ0.data))
9090
else
91-
to_dense(_CType(T), mat2vec(ket2dm(ψ0).data))
91+
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
9292
end
9393
L = L_evo.data
9494

src/time_evolution/sesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,14 @@ function sesolveProblem(
6161
haskey(kwargs, :save_idxs) &&
6262
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
6363

64-
tlist = _check_tlist(tlist, _FType(ψ0))
64+
tlist = _check_tlist(tlist, _float_type(ψ0))
6565

6666
H_evo = _sesolve_make_U_QobjEvo(H) # Multiply by -i
6767
isoper(H_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
6868
check_dimensions(H_evo, ψ0)
6969

7070
T = Base.promote_eltype(H_evo, ψ0)
71-
ψ0 = to_dense(_CType(T), get_data(ψ0)) # Convert it to dense vector with complex element type
71+
ψ0 = to_dense(_complex_float_type(T), get_data(ψ0)) # Convert it to dense vector with complex element type
7272
U = H_evo.data
7373

7474
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)

src/time_evolution/smesolve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,17 @@ function smesolveProblem(
9494
sc_ops_list = _make_c_ops_list(sc_ops) # If it is an AbstractQuantumObject but we need to iterate
9595
sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject
9696

97-
tlist = _check_tlist(tlist, _FType(ψ0))
97+
tlist = _check_tlist(tlist, _float_type(ψ0))
9898

9999
L_evo = _mesolve_make_L_QobjEvo(H, c_ops) + _mesolve_make_L_QobjEvo(nothing, sc_ops_list)
100100
check_dimensions(L_evo, ψ0)
101101
dims = L_evo.dimensions
102102

103103
T = Base.promote_eltype(L_evo, ψ0)
104104
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
105-
to_dense(_CType(T), copy(ψ0.data))
105+
to_dense(_complex_float_type(T), copy(ψ0.data))
106106
else
107-
to_dense(_CType(T), mat2vec(ket2dm(ψ0).data))
107+
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
108108
end
109109

110110
progr = ProgressBar(length(tlist), enable = getVal(progress_bar))

src/time_evolution/ssesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ function ssesolveProblem(
9494
sc_ops_list = _make_c_ops_list(sc_ops) # If it is an AbstractQuantumObject but we need to iterate
9595
sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject
9696

97-
tlist = _check_tlist(tlist, _FType(ψ0))
97+
tlist = _check_tlist(tlist, _float_type(ψ0))
9898

9999
H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, sc_ops_list)
100100
isoper(H_eff_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
101101
check_dimensions(H_eff_evo, ψ0)
102102
dims = H_eff_evo.dimensions
103103

104-
ψ0 = to_dense(_CType(ψ0), get_data(ψ0))
104+
ψ0 = to_dense(_complex_float_type(ψ0), get_data(ψ0))
105105

106106
progr = ProgressBar(length(tlist), enable = getVal(progress_bar))
107107

0 commit comments

Comments
 (0)