diff --git a/ext/QuantumToolboxCUDAExt.jl b/ext/QuantumToolboxCUDAExt.jl index e1b3418bb..7d379c3a1 100644 --- a/ext/QuantumToolboxCUDAExt.jl +++ b/ext/QuantumToolboxCUDAExt.jl @@ -89,4 +89,10 @@ _change_eltype(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32 _change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64 _change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32 +sparse_to_dense(::Type{T}, A::CuArray{T}) where {T<:Number} = A +sparse_to_dense(::Type{T1}, A::CuArray{T2}) where {T1<:Number,T2<:Number} = CuArray{T1}(A) +sparse_to_dense(::Type{T}, A::CuSparseVector) where {T<:Number} = CuArray{T}(A) +sparse_to_dense(::Type{T}, A::CuSparseMatrixCSC) where {T<:Number} = CuArray{T}(A) +sparse_to_dense(::Type{T}, A::CuSparseMatrixCSR) where {T<:Number} = CuArray{T}(A) + end diff --git a/src/qobj/functions.jl b/src/qobj/functions.jl index a69d8c108..de729d08c 100644 --- a/src/qobj/functions.jl +++ b/src/qobj/functions.jl @@ -108,12 +108,16 @@ variance(O::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject}, ψ::Vector Converts a sparse QuantumObject to a dense QuantumObject. """ sparse_to_dense(A::QuantumObject{<:AbstractVecOrMat}) = QuantumObject(sparse_to_dense(A.data), A.type, A.dims) -sparse_to_dense(A::MT) where {MT<:AbstractSparseMatrix} = Array(A) +sparse_to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A) for op in (:Transpose, :Adjoint) @eval sparse_to_dense(A::$op{T,<:AbstractSparseMatrix}) where {T<:BlasFloat} = Array(A) end sparse_to_dense(A::MT) where {MT<:AbstractArray} = A +sparse_to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A) +sparse_to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A) +sparse_to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A + function sparse_to_dense(::Type{M}) where {M<:SparseMatrixCSC} T = M par = T.parameters diff --git a/src/qobj/quantum_object.jl b/src/qobj/quantum_object.jl index 3a2ce8ed8..f55244b2b 100644 --- a/src/qobj/quantum_object.jl +++ b/src/qobj/quantum_object.jl @@ -368,3 +368,7 @@ SparseArrays.SparseMatrixCSC(A::QuantumObject{<:AbstractMatrix}) = QuantumObject(SparseMatrixCSC(A.data), A.type, A.dims) SparseArrays.SparseMatrixCSC{T}(A::QuantumObject{<:SparseMatrixCSC}) where {T<:Number} = QuantumObject(SparseMatrixCSC{T}(A.data), A.type, A.dims) + +# functions for getting Float or Complex element type +_FType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _FType(T) +_CType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _CType(T) diff --git a/src/steadystate.jl b/src/steadystate.jl index be06cbf1a..080f17281 100644 --- a/src/steadystate.jl +++ b/src/steadystate.jl @@ -95,10 +95,12 @@ function steadystate( (H.dims != ψ0.dims) && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension.")) N = prod(H.dims) - u0 = mat2vec(ket2dm(ψ0).data) + u0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) + L = MatrixOperator(liouvillian(H, c_ops).data) - prob = ODEProblem{true}(L, u0, (0.0, tspan)) + ftype = _FType(ψ0) + prob = ODEProblem{true}(L, u0, (ftype(0), ftype(tspan))) # Convert tspan to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl sol = solve( prob, solver.alg; @@ -109,7 +111,6 @@ function steadystate( ) ρss = reshape(sol.u[end], N, N) - ρss = (ρss + ρss') / 2 # Hermitianize return QuantumObject(ρss, Operator, H.dims) end diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index bd62bb410..34a722ba8 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -190,7 +190,7 @@ function mcsolveProblem( c_ops isa Nothing && throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead.")) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2 diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index f83a1bc3b..95a591f58 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -122,9 +122,9 @@ function mesolveProblem( is_time_dependent = !(H_t isa Nothing) progress_bar_val = makeVal(progress_bar) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type - ρ0 = mat2vec(ket2dm(ψ0).data) + t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl L = liouvillian(H, c_ops).data progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val)) diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index 6ef530d65..264a6527b 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -103,9 +103,9 @@ function sesolveProblem( is_time_dependent = !(H_t isa Nothing) progress_bar_val = makeVal(progress_bar) - t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl + ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type - ϕ0 = get_data(ψ0) + t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl U = -1im * get_data(H) progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val)) diff --git a/src/utilities.jl b/src/utilities.jl index df93872d4..a2888b70c 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -65,3 +65,19 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} = @warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" * join(arg, ", ") * ")` instead of `$argname = $arg`." maxlog = 1 + +# functions for getting Float or Complex element type +_FType(::AbstractArray{T}) where {T<:Number} = _FType(T) +_FType(::Type{Int32}) = Float32 +_FType(::Type{Int64}) = Float64 +_FType(::Type{Float32}) = Float32 +_FType(::Type{Float64}) = Float64 +_FType(::Type{ComplexF32}) = Float32 +_FType(::Type{ComplexF64}) = Float64 +_CType(::AbstractArray{T}) where {T<:Number} = _CType(T) +_CType(::Type{Int32}) = ComplexF32 +_CType(::Type{Int64}) = ComplexF64 +_CType(::Type{Float32}) = ComplexF32 +_CType(::Type{Float64}) = ComplexF64 +_CType(::Type{ComplexF32}) = ComplexF32 +_CType(::Type{ComplexF64}) = ComplexF64 diff --git a/test/core-test/time_evolution.jl b/test/core-test/time_evolution.jl index b4d097dea..f86d0bfe6 100644 --- a/test/core-test/time_evolution.jl +++ b/test/core-test/time_evolution.jl @@ -33,7 +33,9 @@ "reltol = $(sol.reltol)\n" @testset "Type Inference sesolve" begin - @inferred sesolveProblem(H, psi0, t_l) + @inferred sesolveProblem(H, psi0, t_l, progress_bar = Val(false)) + @inferred sesolveProblem(H, psi0, [0, 10], progress_bar = Val(false)) + @inferred sesolveProblem(H, Qobj(zeros(Int64, N * 2); dims = (N, 2)), t_l, progress_bar = Val(false)) @inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false)) @inferred sesolve(H, psi0, t_l, progress_bar = Val(false)) @inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false)) @@ -91,6 +93,8 @@ @testset "Type Inference mesolve" begin @inferred mesolveProblem(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false)) + @inferred mesolveProblem(H, psi0, [0, 10], c_ops, e_ops = e_ops, progress_bar = Val(false)) + @inferred mesolveProblem(H, Qobj(zeros(Int64, N)), t_l, c_ops, e_ops = e_ops, progress_bar = Val(false)) @inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false)) @inferred mesolve(H, psi0, t_l, c_ops, progress_bar = Val(false)) @inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = Val(false)) @@ -108,6 +112,8 @@ ) @inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false)) @inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, progress_bar = Val(true)) + @inferred mcsolve(H, psi0, [0, 10], c_ops, n_traj = 500, progress_bar = Val(false)) + @inferred mcsolve(H, Qobj(zeros(Int64, N)), t_l, c_ops, n_traj = 500, progress_bar = Val(false)) end end