Skip to content

Commit 6bca739

Browse files
committed
handle type of u0 with sparse_to_dense
1 parent 3f5f4fe commit 6bca739

File tree

10 files changed

+45
-16
lines changed

10 files changed

+45
-16
lines changed

ext/QuantumToolboxCUDAExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module QuantumToolboxCUDAExt
22

33
using QuantumToolbox
4-
import QuantumToolbox: _convert_u0
54
import CUDA: cu, CuArray
65
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR
76
import SparseArrays: SparseVector, SparseMatrixCSC
@@ -90,7 +89,10 @@ _change_eltype(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
9089
_change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
9190
_change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32
9291

93-
# make sure u0 in time evolution is dense vector and has complex element type
94-
_convert_u0(u0::Union{CuArray{T},CuSparseVector{T}}) where {T<:Number} = convert(CuArray{complex(T)}, u0)
92+
sparse_to_dense(::Type{T}, A::CuArray{T}) where {T<:Number} = A
93+
sparse_to_dense(::Type{T1}, A::CuArray{T2}) where {T1<:Number,T2<:Number} = CuArray{T1}(A)
94+
sparse_to_dense(::Type{T}, A::CuSparseVector) where {T<:Number} = CuArray{T}(A)
95+
sparse_to_dense(::Type{T}, A::CuSparseMatrixCSC) where {T<:Number} = CuArray{T}(A)
96+
sparse_to_dense(::Type{T}, A::CuSparseMatrixCSR) where {T<:Number} = CuArray{T}(A)
9597

9698
end

src/qobj/functions.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,16 @@ variance(O::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject}, ψ::Vector
108108
Converts a sparse QuantumObject to a dense QuantumObject.
109109
"""
110110
sparse_to_dense(A::QuantumObject{<:AbstractVecOrMat}) = QuantumObject(sparse_to_dense(A.data), A.type, A.dims)
111-
sparse_to_dense(A::MT) where {MT<:AbstractSparseMatrix} = Array(A)
111+
sparse_to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A)
112112
for op in (:Transpose, :Adjoint)
113113
@eval sparse_to_dense(A::$op{T,<:AbstractSparseMatrix}) where {T<:BlasFloat} = Array(A)
114114
end
115115
sparse_to_dense(A::MT) where {MT<:AbstractArray} = A
116116

117+
sparse_to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A)
118+
sparse_to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A)
119+
sparse_to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A
120+
117121
function sparse_to_dense(::Type{M}) where {M<:SparseMatrixCSC}
118122
T = M
119123
par = T.parameters

src/qobj/quantum_object.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,7 @@ SparseArrays.SparseMatrixCSC(A::QuantumObject{<:AbstractMatrix}) =
368368
QuantumObject(SparseMatrixCSC(A.data), A.type, A.dims)
369369
SparseArrays.SparseMatrixCSC{T}(A::QuantumObject{<:SparseMatrixCSC}) where {T<:Number} =
370370
QuantumObject(SparseMatrixCSC{T}(A.data), A.type, A.dims)
371+
372+
# functions for getting Float or Complex element type
373+
_FType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _FType(T)
374+
_CType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _CType(T)

src/steadystate.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,12 @@ function steadystate(
9595
(H.dims != ψ0.dims) && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))
9696

9797
N = prod(H.dims)
98-
u0 = _convert_u0(mat2vec(ket2dm(ψ0).data))
98+
u0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data))
9999

100100
L = MatrixOperator(liouvillian(H, c_ops).data)
101101

102-
Ftype = real(eltype(u0))
103-
prob = ODEProblem{true}(L, u0, (Ftype(0), Ftype(tspan))) # Convert tspan to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
102+
ftype = _FType(ψ0)
103+
prob = ODEProblem{true}(L, u0, (ftype(0), ftype(tspan))) # Convert tspan to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
104104
sol = solve(
105105
prob,
106106
solver.alg;

src/time_evolution/mcsolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function mcsolveProblem(
190190
c_ops isa Nothing &&
191191
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
192192

193-
t_l = convert(Vector{real(eltype(ψ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
193+
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
194194

195195
H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2
196196

src/time_evolution/mesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ function mesolveProblem(
122122
is_time_dependent = !(H_t isa Nothing)
123123
progress_bar_val = makeVal(progress_bar)
124124

125-
ρ0 = _convert_u0(mat2vec(ket2dm(ψ0).data))
125+
ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type
126126

127-
t_l = convert(Vector{real(eltype(ρ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
127+
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
128128

129129
L = liouvillian(H, c_ops).data
130130
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

src/time_evolution/sesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ function sesolveProblem(
103103
is_time_dependent = !(H_t isa Nothing)
104104
progress_bar_val = makeVal(progress_bar)
105105

106-
ϕ0 = _convert_u0(get_data(ψ0))
106+
ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type
107107

108-
t_l = convert(Vector{real(eltype(ϕ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
108+
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
109109

110110
U = -1im * get_data(H)
111111
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

src/utilities.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,18 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
6666
join(arg, ", ") *
6767
")` instead of `$argname = $arg`." maxlog = 1
6868

69-
# make sure u0 in time evolution is dense vector and has complex element type
70-
_convert_u0(u0::AbstractVector{T}) where {T<:Number} = convert(Vector{complex(T)}, u0)
69+
# functions for getting Float or Complex element type
70+
_FType(::AbstractArray{T}) where {T<:Number} = _FType(T)
71+
_FType(::Type{Int32}) = Float32
72+
_FType(::Type{Int64}) = Float64
73+
_FType(::Type{Float32}) = Float32
74+
_FType(::Type{Float64}) = Float64
75+
_FType(::Type{ComplexF32}) = Float32
76+
_FType(::Type{ComplexF64}) = Float64
77+
_CType(::AbstractArray{T}) where {T<:Number} = _CType(T)
78+
_CType(::Type{Int32}) = ComplexF32
79+
_CType(::Type{Int64}) = ComplexF64
80+
_CType(::Type{Float32}) = ComplexF32
81+
_CType(::Type{Float64}) = ComplexF64
82+
_CType(::Type{ComplexF32}) = ComplexF32
83+
_CType(::Type{ComplexF64}) = ComplexF64

test/core-test/time_evolution.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
"reltol = $(sol.reltol)\n"
3434

3535
@testset "Type Inference sesolve" begin
36-
@inferred sesolveProblem(H, psi0, t_l)
36+
@inferred sesolveProblem(H, psi0, t_l, progress_bar = Val(false))
37+
@inferred sesolveProblem(H, psi0, [0, 10], progress_bar = Val(false))
38+
@inferred sesolveProblem(H, Qobj(zeros(Int64, N * 2); dims = (N, 2)), t_l, progress_bar = Val(false))
3739
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false))
3840
@inferred sesolve(H, psi0, t_l, progress_bar = Val(false))
3941
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
@@ -91,6 +93,8 @@
9193

9294
@testset "Type Inference mesolve" begin
9395
@inferred mesolveProblem(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
96+
@inferred mesolveProblem(H, psi0, [0, 10], c_ops, e_ops = e_ops, progress_bar = Val(false))
97+
@inferred mesolveProblem(H, Qobj(zeros(Int64, N)), t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
9498
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
9599
@inferred mesolve(H, psi0, t_l, c_ops, progress_bar = Val(false))
96100
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
@@ -108,6 +112,8 @@
108112
)
109113
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
110114
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, progress_bar = Val(true))
115+
@inferred mcsolve(H, psi0, [0, 10], c_ops, n_traj = 500, progress_bar = Val(false))
116+
@inferred mcsolve(H, Qobj(zeros(Int64, N)), t_l, c_ops, n_traj = 500, progress_bar = Val(false))
111117
end
112118
end
113119

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ if (GROUP == "All") || (GROUP == "Core")
3939
end
4040
end
4141

42-
if (GROUP == "CUDA_Ext")# || (GROUP == "All")
42+
if (GROUP == "CUDA_Ext") || (GROUP == "All")
4343
Pkg.add("CUDA")
4444
include(joinpath(testdir, "ext-test", "cuda_ext.jl"))
4545
end

0 commit comments

Comments
 (0)