Skip to content

Commit 5af51a2

Browse files
authored
Fix type conversion of tlist in time evolution (#229)
2 parents fb48447 + 725fa0b commit 5af51a2

File tree

9 files changed

+47
-10
lines changed

9 files changed

+47
-10
lines changed

ext/QuantumToolboxCUDAExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,10 @@ _change_eltype(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
8989
_change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
9090
_change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32
9191

92+
sparse_to_dense(::Type{T}, A::CuArray{T}) where {T<:Number} = A
93+
sparse_to_dense(::Type{T1}, A::CuArray{T2}) where {T1<:Number,T2<:Number} = CuArray{T1}(A)
94+
sparse_to_dense(::Type{T}, A::CuSparseVector) where {T<:Number} = CuArray{T}(A)
95+
sparse_to_dense(::Type{T}, A::CuSparseMatrixCSC) where {T<:Number} = CuArray{T}(A)
96+
sparse_to_dense(::Type{T}, A::CuSparseMatrixCSR) where {T<:Number} = CuArray{T}(A)
97+
9298
end

src/qobj/functions.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,16 @@ variance(O::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject}, ψ::Vector
108108
Converts a sparse QuantumObject to a dense QuantumObject.
109109
"""
110110
sparse_to_dense(A::QuantumObject{<:AbstractVecOrMat}) = QuantumObject(sparse_to_dense(A.data), A.type, A.dims)
111-
sparse_to_dense(A::MT) where {MT<:AbstractSparseMatrix} = Array(A)
111+
sparse_to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A)
112112
for op in (:Transpose, :Adjoint)
113113
@eval sparse_to_dense(A::$op{T,<:AbstractSparseMatrix}) where {T<:BlasFloat} = Array(A)
114114
end
115115
sparse_to_dense(A::MT) where {MT<:AbstractArray} = A
116116

117+
sparse_to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A)
118+
sparse_to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A)
119+
sparse_to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A
120+
117121
function sparse_to_dense(::Type{M}) where {M<:SparseMatrixCSC}
118122
T = M
119123
par = T.parameters

src/qobj/quantum_object.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,7 @@ SparseArrays.SparseMatrixCSC(A::QuantumObject{<:AbstractMatrix}) =
368368
QuantumObject(SparseMatrixCSC(A.data), A.type, A.dims)
369369
SparseArrays.SparseMatrixCSC{T}(A::QuantumObject{<:SparseMatrixCSC}) where {T<:Number} =
370370
QuantumObject(SparseMatrixCSC{T}(A.data), A.type, A.dims)
371+
372+
# functions for getting Float or Complex element type
373+
_FType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _FType(T)
374+
_CType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _CType(T)

src/steadystate.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,12 @@ function steadystate(
9595
(H.dims != ψ0.dims) && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))
9696

9797
N = prod(H.dims)
98-
u0 = mat2vec(ket2dm(ψ0).data)
98+
u0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data))
99+
99100
L = MatrixOperator(liouvillian(H, c_ops).data)
100101

101-
prob = ODEProblem{true}(L, u0, (0.0, tspan))
102+
ftype = _FType(ψ0)
103+
prob = ODEProblem{true}(L, u0, (ftype(0), ftype(tspan))) # Convert tspan to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
102104
sol = solve(
103105
prob,
104106
solver.alg;
@@ -109,7 +111,6 @@ function steadystate(
109111
)
110112

111113
ρss = reshape(sol.u[end], N, N)
112-
ρss = (ρss + ρss') / 2 # Hermitianize
113114
return QuantumObject(ρss, Operator, H.dims)
114115
end
115116

src/time_evolution/mcsolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function mcsolveProblem(
190190
c_ops isa Nothing &&
191191
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
192192

193-
t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
193+
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
194194

195195
H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2
196196

src/time_evolution/mesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ function mesolveProblem(
122122
is_time_dependent = !(H_t isa Nothing)
123123
progress_bar_val = makeVal(progress_bar)
124124

125-
t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
125+
ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type
126126

127-
ρ0 = mat2vec(ket2dm(ψ0).data)
127+
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
128128

129129
L = liouvillian(H, c_ops).data
130130
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

src/time_evolution/sesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ function sesolveProblem(
103103
is_time_dependent = !(H_t isa Nothing)
104104
progress_bar_val = makeVal(progress_bar)
105105

106-
t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
106+
ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type
107107

108-
ϕ0 = get_data(ψ0)
108+
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
109109

110110
U = -1im * get_data(H)
111111
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

src/utilities.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,19 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
6565
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
6666
join(arg, ", ") *
6767
")` instead of `$argname = $arg`." maxlog = 1
68+
69+
# functions for getting Float or Complex element type
70+
_FType(::AbstractArray{T}) where {T<:Number} = _FType(T)
71+
_FType(::Type{Int32}) = Float32
72+
_FType(::Type{Int64}) = Float64
73+
_FType(::Type{Float32}) = Float32
74+
_FType(::Type{Float64}) = Float64
75+
_FType(::Type{ComplexF32}) = Float32
76+
_FType(::Type{ComplexF64}) = Float64
77+
_CType(::AbstractArray{T}) where {T<:Number} = _CType(T)
78+
_CType(::Type{Int32}) = ComplexF32
79+
_CType(::Type{Int64}) = ComplexF64
80+
_CType(::Type{Float32}) = ComplexF32
81+
_CType(::Type{Float64}) = ComplexF64
82+
_CType(::Type{ComplexF32}) = ComplexF32
83+
_CType(::Type{ComplexF64}) = ComplexF64

test/core-test/time_evolution.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
"reltol = $(sol.reltol)\n"
3434

3535
@testset "Type Inference sesolve" begin
36-
@inferred sesolveProblem(H, psi0, t_l)
36+
@inferred sesolveProblem(H, psi0, t_l, progress_bar = Val(false))
37+
@inferred sesolveProblem(H, psi0, [0, 10], progress_bar = Val(false))
38+
@inferred sesolveProblem(H, Qobj(zeros(Int64, N * 2); dims = (N, 2)), t_l, progress_bar = Val(false))
3739
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false))
3840
@inferred sesolve(H, psi0, t_l, progress_bar = Val(false))
3941
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
@@ -91,6 +93,8 @@
9193

9294
@testset "Type Inference mesolve" begin
9395
@inferred mesolveProblem(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
96+
@inferred mesolveProblem(H, psi0, [0, 10], c_ops, e_ops = e_ops, progress_bar = Val(false))
97+
@inferred mesolveProblem(H, Qobj(zeros(Int64, N)), t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
9498
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
9599
@inferred mesolve(H, psi0, t_l, c_ops, progress_bar = Val(false))
96100
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
@@ -108,6 +112,8 @@
108112
)
109113
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
110114
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, progress_bar = Val(true))
115+
@inferred mcsolve(H, psi0, [0, 10], c_ops, n_traj = 500, progress_bar = Val(false))
116+
@inferred mcsolve(H, Qobj(zeros(Int64, N)), t_l, c_ops, n_traj = 500, progress_bar = Val(false))
111117
end
112118
end
113119

0 commit comments

Comments
 (0)