diff --git a/ext/QuantumToolboxCUDAExt.jl b/ext/QuantumToolboxCUDAExt.jl index b4fcf0fc2..11a764244 100644 --- a/ext/QuantumToolboxCUDAExt.jl +++ b/ext/QuantumToolboxCUDAExt.jl @@ -2,7 +2,7 @@ module QuantumToolboxCUDAExt using QuantumToolbox using QuantumToolbox: makeVal, getVal -import QuantumToolbox: _sparse_similar +import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize import CUDA: cu, CuArray, allowscalar import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR, AbstractCuSparseArray import SparseArrays: SparseVector, SparseMatrixCSC @@ -81,27 +81,20 @@ function cu(A::QuantumObject; word_size::Union{Val,Int} = Val(64)) return cu(A, makeVal(word_size)) end -cu(A::QuantumObject, word_size::Union{Val{32},Val{64}}) = CuArray{_change_eltype(eltype(A), word_size)}(A) +cu(A::QuantumObject, word_size::Union{Val{32},Val{64}}) = CuArray{_convert_eltype_wordsize(eltype(A), word_size)}(A) function cu( A::QuantumObject{ObjType,DimsType,<:SparseVector}, word_size::Union{Val{32},Val{64}}, ) where {ObjType<:QuantumObjectType,DimsType<:AbstractDimensions} - return CuSparseVector{_change_eltype(eltype(A), word_size)}(A) + return CuSparseVector{_convert_eltype_wordsize(eltype(A), word_size)}(A) end function cu( A::QuantumObject{ObjType,DimsType,<:SparseMatrixCSC}, word_size::Union{Val{32},Val{64}}, ) where {ObjType<:QuantumObjectType,DimsType<:AbstractDimensions} - return CuSparseMatrixCSC{_change_eltype(eltype(A), word_size)}(A) + return CuSparseMatrixCSC{_convert_eltype_wordsize(eltype(A), word_size)}(A) end -_change_eltype(::Type{T}, ::Val{64}) where {T<:Int} = Int64 -_change_eltype(::Type{T}, ::Val{32}) where {T<:Int} = Int32 -_change_eltype(::Type{T}, ::Val{64}) where {T<:AbstractFloat} = Float64 -_change_eltype(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32 -_change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64 -_change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32 - QuantumToolbox.to_dense(A::MT) where {MT<:AbstractCuSparseArray} = CuArray(A) QuantumToolbox.to_dense(::Type{T1}, A::CuArray{T2}) where {T1<:Number,T2<:Number} = CuArray{T1}(A) diff --git a/src/utilities.jl b/src/utilities.jl index 0a660995f..a3547cf61 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -190,3 +190,10 @@ _CType(::Type{Complex{Int32}}) = ComplexF32 _CType(::Type{Complex{Int64}}) = ComplexF64 _CType(::Type{Complex{Float32}}) = ComplexF32 _CType(::Type{Complex{Float64}}) = ComplexF64 + +_convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:Int} = Int64 +_convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:Int} = Int32 +_convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:AbstractFloat} = Float64 +_convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32 +_convert_eltype_wordsize(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64 +_convert_eltype_wordsize(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32